[jiterator] sqrt-rsqrt : complex
Ports the CUDA `sqrt` and `rsqrt` kernels for complex dtypes to the jiterator, as per the title.
Relies on the existing OpInfo tests in `test_ops.py` and `test_unary_ufuncs.py`; no new tests are added.
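For reference, a minimal sketch of exercising the new complex paths from Python (assumes a CUDA build; the tensor shape and values are arbitrary, and this snippet is not part of the PR's test coverage):

    import torch

    x = torch.randn(8, dtype=torch.complex64, device="cuda")
    # sqrt/rsqrt on complex CUDA tensors now go through the jiterator when
    # AT_USE_JITERATOR() is on, and the precompiled gpu_kernel path otherwise;
    # either way the results should match the references below.
    torch.testing.assert_close(torch.sqrt(x), x ** 0.5)
    torch.testing.assert_close(torch.rsqrt(x), 1 / torch.sqrt(x))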
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73781
Approved by: https://github.com/anjali411
diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
index 064e0da..3031706 100644
--- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -80,19 +80,45 @@
// We manually overload rsqrt because std::rsqrt does not work with complex types.
template<typename scalar_t>
-__host__ __device__ static inline scalar_t rsqrt_wrapper(scalar_t v) {
+C10_HOST_DEVICE static inline scalar_t rsqrt_wrapper(scalar_t v) {
return ::rsqrt(v);
}
template<typename T>
-__host__ __device__ static inline c10::complex<T> rsqrt_wrapper(c10::complex<T> v) {
+C10_HOST_DEVICE static inline c10::complex<T> rsqrt_wrapper(c10::complex<T> v) {
const c10::complex<T> one = c10::complex<T>(1.0, 0);
// std::sqrt for c10::complex is overloaded in c10/util/complex_math.h
return one / ::sqrt(v);
}
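+// Name of the JIT-generated kernel; it must match the function defined in rsqrt_string below.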
+const char rsqrt_name[] = "rsqrt_kernel";
void rsqrt_kernel_cuda(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+ auto common_dtype = iter.common_dtype();
+ if (at::isComplexType(common_dtype)) {
+ #if AT_USE_JITERATOR()
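+ // The kernel body is shipped as a string and JIT-compiled at run time by the jiterator.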
+ static const auto rsqrt_string = jiterator_stringify(
+ template <typename T>
+ T rsqrt_kernel(T x) {
+ const T one = T{1};
+ return one / std::sqrt(x);
+ }); // rsqrt_string
+ AT_DISPATCH_COMPLEX_TYPES(common_dtype, "rsqrt_cuda", [&]() {
+ jitted_gpu_kernel<
+ /*name=*/rsqrt_name,
+ /*return_dtype=*/scalar_t,
+ /*common_dtype=*/scalar_t,
+ /*arity=*/1>(iter, rsqrt_string);
+ });
+ #else
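+ // Jiterator not available: fall back to the precompiled gpu_kernel path.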
+ AT_DISPATCH_COMPLEX_TYPES(common_dtype, "rsqrt_cuda", [&]() {
+ gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+ // rsqrt_wrapper is overloaded for c10::complex above and computes one / std::sqrt(v).
+ return rsqrt_wrapper(a);
+ });
+ });
+ #endif
+ } else {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half,
iter.common_dtype(), "rsqrt_cuda",
[&]() {
@@ -101,14 +127,40 @@
return rsqrt_wrapper(a);
});
});
+ }
}
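+// As with rsqrt_name: must match the function defined in sqrt_string below.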
+const char sqrt_name[] = "sqrt_kernel";
void sqrt_kernel_cuda(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() {
- gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
- return ::sqrt(a);
+ auto common_dtype = iter.common_dtype();
+ if (at::isComplexType(common_dtype)) {
+ #if AT_USE_JITERATOR()
+ static const auto sqrt_string = jiterator_stringify(
+ template <typename T>
+ T sqrt_kernel(T x) {
+ return std::sqrt(x);
+ }); // sqrt_string
+ AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sqrt_cuda", [&]() {
+ jitted_gpu_kernel<
+ /*name=*/sqrt_name,
+ /*return_dtype=*/scalar_t,
+ /*common_dtype=*/scalar_t,
+ /*arity=*/1>(iter, sqrt_string);
+ });
+ #else
+ AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sqrt_cuda", [&]() {
+ gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+ return std::sqrt(a);
+ });
+ });
+ #endif
+ } else {
+ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, common_dtype, "sqrt_cuda", [&]() {
+ gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+ return std::sqrt(a);
+ });
});
- });
+ }
}
void clamp_kernel_cuda(TensorIteratorBase& iter, const Scalar& min_value, const Scalar& max_value) {