[jiterator] sqrt, rsqrt : complex

Ports the complex `sqrt` and `rsqrt` CUDA kernels to the jiterator. The existing `gpu_kernel` implementations are kept as a fallback for builds where `AT_USE_JITERATOR()` is off, and the non-complex dtypes keep their current dispatch.

Relies on the existing OpInfo tests in `test_ops.py` and `test_unary_ufuncs.py`; no new tests are added.
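
A quick smoke test of the new paths (a minimal sketch, not part of this PR; assumes a CUDA build, and only exercises the jiterator path when `AT_USE_JITERATOR()` is on):

```python
import torch

# Complex inputs now dispatch to the jiterator on CUDA
# (or to the gpu_kernel fallback when the jiterator is disabled).
x = torch.randn(8, dtype=torch.complex64, device="cuda")

# For complex inputs, rsqrt(x) is computed as 1 / sqrt(x)
# (see rsqrt_wrapper / rsqrt_string in the diff below), so this
# identity should hold to rounding error.
torch.testing.assert_close(torch.rsqrt(x), 1 / torch.sqrt(x))

# Cross-check the CUDA kernel against the CPU reference.
torch.testing.assert_close(torch.sqrt(x).cpu(), torch.sqrt(x.cpu()))
```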
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73781
Approved by: https://github.com/anjali411
diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
index 064e0da..3031706 100644
--- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -80,19 +80,45 @@
 
 // We manually overload rsqrt because std::rsqrt does not work with complex types.
 template<typename scalar_t>
-__host__ __device__ static inline scalar_t rsqrt_wrapper(scalar_t v) {
+C10_HOST_DEVICE static inline scalar_t rsqrt_wrapper(scalar_t v) {
   return ::rsqrt(v);
 }
 
 template<typename T>
-__host__ __device__ static inline c10::complex<T> rsqrt_wrapper(c10::complex<T> v) {
+C10_HOST_DEVICE static inline c10::complex<T> rsqrt_wrapper(c10::complex<T> v) {
   const c10::complex<T> one = c10::complex<T>(1.0, 0);
   // std::sqrt for c10::complex is overloaded in c10/util/complex_math.h
   return one / ::sqrt(v);
 }
 
+const char rsqrt_name[] = "rsqrt_kernel";
 void rsqrt_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+  auto common_dtype = iter.common_dtype();
+  if (at::isComplexType(common_dtype)) {
+    #if AT_USE_JITERATOR()
+      static const auto rsqrt_string = jiterator_stringify(
+          template <typename T>
+          T rsqrt_kernel(T x) {
+            const T one = T{1};
+            return one / std::sqrt(x);
+      }); // rsqrt_string
+      AT_DISPATCH_COMPLEX_TYPES(common_dtype, "rsqrt_cuda", [&]() {
+          jitted_gpu_kernel<
+              /*name=*/rsqrt_name,
+              /*return_dtype=*/scalar_t,
+              /*common_dtype=*/scalar_t,
+              /*arity=*/1>(iter, rsqrt_string);
+      });
+    #else
+      AT_DISPATCH_COMPLEX_TYPES(common_dtype, "rsqrt_cuda", [&]() {
+        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+          // rsqrt_wrapper is overloaded for c10::complex<T> above: it computes 1 / std::sqrt(v).
+          return rsqrt_wrapper(a);
+        });
+      });
+    #endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16, ScalarType::Half,
       iter.common_dtype(), "rsqrt_cuda",
       [&]() {
@@ -101,14 +127,40 @@
           return rsqrt_wrapper(a);
         });
       });
+  }
 }
 
+const char sqrt_name[] = "sqrt_kernel";
 void sqrt_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return ::sqrt(a);
+  auto common_dtype = iter.common_dtype();
+  if (at::isComplexType(common_dtype)) {
+    #if AT_USE_JITERATOR()
+      static const auto sqrt_string = jiterator_stringify(
+          template <typename T>
+          T sqrt_kernel(T x) {
+            return std::sqrt(x);
+      }); // sqrt_string
+      AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sqrt_cuda", [&]() {
+          jitted_gpu_kernel<
+              /*name=*/sqrt_name,
+              /*return_dtype=*/scalar_t,
+              /*common_dtype=*/scalar_t,
+              /*arity=*/1>(iter, sqrt_string);
+      });
+    #else
+      AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sqrt_cuda", [&]() {
+        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+          return std::sqrt(a);
+        });
+      });
+    #endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, common_dtype, "sqrt_cuda", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        return std::sqrt(a);
+      });
     });
-  });
+  }
 }
 
 void clamp_kernel_cuda(TensorIteratorBase& iter, const Scalar& min_value, const Scalar& max_value) {