Check if generator has next normal sample cache methods in normal_distribution (#39816)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39816

This change replaces [`#if !defined(__CUDACC__) && !defined(__HIPCC__)`](https://github.com/pytorch/pytorch/blob/856215509d89c935cd1768ce4b496d4fc0e919a6/aten/src/ATen/core/DistributionsHelper.h#L147) with SFINAE expression that checks if RNG typename has next_double_normal_sample, set_next_double_normal_sample, next_float_normal_sample, set_next_float_normal_sample methods

It is required by (and manually tested with) https://github.com/pytorch/csprng/pull/28

Fixes #39618

Test Plan: Imported from OSS

Differential Revision: D22002599

Pulled By: pbelevich

fbshipit-source-id: e33d42a7e88c5729b077b9cdbf1437158dab48bc
diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h
index cd08a56..071af83 100644
--- a/aten/src/ATen/core/DistributionsHelper.h
+++ b/aten/src/ATen/core/DistributionsHelper.h
@@ -126,6 +126,67 @@
     T to_;
 };
 
+// The SFINAE checks introduced in #39816 looks overcomplicated and must revisited
+// https://github.com/pytorch/pytorch/issues/40052
+#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member)              \
+template <typename T>                                                \
+struct has_member_##member                                           \
+{                                                                    \
+    typedef char yes;                                                \
+    typedef long no;                                                 \
+    template <typename U> static yes test(decltype(&U::member));     \
+    template <typename U> static no test(...);                       \
+    static constexpr bool value = sizeof(test<T>(0)) == sizeof(yes); \
+}
+
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
+
+#define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE)                                      \
+                                                                                                    \
+template <typename RNG, typename ret_type,                                                          \
+          typename std::enable_if_t<(                                                               \
+            has_member_next_##TYPE##_normal_sample<RNG>::value &&                                   \
+            has_member_set_next_##TYPE##_normal_sample<RNG>::value                                  \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) {  \
+  if (generator->next_##TYPE##_normal_sample()) {                                                   \
+    *ret = *(generator->next_##TYPE##_normal_sample());                                             \
+    generator->set_next_##TYPE##_normal_sample(c10::optional<TYPE>());                              \
+    return true;                                                                                    \
+  }                                                                                                 \
+  return false;                                                                                     \
+}                                                                                                   \
+                                                                                                    \
+template <typename RNG, typename ret_type,                                                          \
+          typename std::enable_if_t<(                                                               \
+            !has_member_next_##TYPE##_normal_sample<RNG>::value ||                                  \
+            !has_member_set_next_##TYPE##_normal_sample<RNG>::value                                 \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) {  \
+  return false;                                                                                     \
+}                                                                                                   \
+                                                                                                    \
+template <typename RNG, typename ret_type,                                                          \
+          typename std::enable_if_t<(                                                               \
+            has_member_set_next_##TYPE##_normal_sample<RNG>::value                                  \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
+  generator->set_next_##TYPE##_normal_sample(cache);                                                \
+}                                                                                                   \
+                                                                                                    \
+template <typename RNG, typename ret_type,                                                          \
+          typename std::enable_if_t<(                                                               \
+            !has_member_set_next_##TYPE##_normal_sample<RNG>::value                                 \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
+}
+
+DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double);
+DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float);
+
 /**
  * Samples a normal distribution using the Box-Muller method
  * Takes mean and standard deviation as inputs
@@ -144,41 +205,29 @@
   template <typename RNG>
   C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
     dist_acctype<T> ret;
-#if !defined(__CUDACC__) && !defined(__HIPCC__)
     // return cached values if available
     if (std::is_same<T, double>::value) {
-      if (generator->next_double_normal_sample()) {
-        ret = *(generator->next_double_normal_sample()) * stdv + mean;
-        // reset c10::optional to null
-        generator->set_next_double_normal_sample(c10::optional<double>());
-        return ret;
+      if (maybe_get_next_double_normal_sample(generator, &ret)) {
+        return transformation::normal(ret, mean, stdv);
       }
     } else {
-      if (generator->next_float_normal_sample()) {
-        ret = *(generator->next_float_normal_sample()) * stdv + mean;
-        // reset c10::optional to null
-        generator->set_next_float_normal_sample(c10::optional<float>());
-        return ret;
+      if (maybe_get_next_float_normal_sample(generator, &ret)) {
+        return transformation::normal(ret, mean, stdv);
       }
     }
-#endif
     // otherwise generate new normal values
     uniform_real_distribution<T> uniform(0.0, 1.0);
     const dist_acctype<T> u1 = uniform(generator);
     const dist_acctype<T> u2 = uniform(generator);
     const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log(static_cast<T>(1.0)-u2));
     const dist_acctype<T> theta = static_cast<T>(2.0) * static_cast<T>(M_PI) * u1;
-#if !defined(__CUDACC__) && !defined(__HIPCC__)
     if (std::is_same<T, double>::value) {
-      dist_acctype<double> cache = r * ::sin(theta);
-      generator->set_next_double_normal_sample(c10::optional<double>(cache));
+      maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
     } else {
-      dist_acctype<float> cache = r * ::sin(theta);
-      generator->set_next_float_normal_sample(c10::optional<float>(cache));
+      maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
     }
-#endif
-    ret = transformation::normal(r * ::cos(theta), mean, stdv);
-    return ret;
+    ret = r * ::cos(theta);
+    return transformation::normal(ret, mean, stdv);
   }
 
   private: