Check if generator has next normal sample cache methods in normal_distribution (#39816)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39816
This change replaces [`#if !defined(__CUDACC__) && !defined(__HIPCC__)`](https://github.com/pytorch/pytorch/blob/856215509d89c935cd1768ce4b496d4fc0e919a6/aten/src/ATen/core/DistributionsHelper.h#L147) with SFINAE expression that checks if RNG typename has next_double_normal_sample, set_next_double_normal_sample, next_float_normal_sample, set_next_float_normal_sample methods
It is required by (and manually tested with) https://github.com/pytorch/csprng/pull/28
Fixes #39618
Test Plan: Imported from OSS
Differential Revision: D22002599
Pulled By: pbelevich
fbshipit-source-id: e33d42a7e88c5729b077b9cdbf1437158dab48bc
diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h
index cd08a56..071af83 100644
--- a/aten/src/ATen/core/DistributionsHelper.h
+++ b/aten/src/ATen/core/DistributionsHelper.h
@@ -126,6 +126,67 @@
T to_;
};
+// The SFINAE checks introduced in #39816 looks overcomplicated and must revisited
+// https://github.com/pytorch/pytorch/issues/40052
+#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member) \
+template <typename T> \
+struct has_member_##member \
+{ \
+ typedef char yes; \
+ typedef long no; \
+ template <typename U> static yes test(decltype(&U::member)); \
+ template <typename U> static no test(...); \
+ static constexpr bool value = sizeof(test<T>(0)) == sizeof(yes); \
+}
+
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
+
+#define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE) \
+ \
+template <typename RNG, typename ret_type, \
+ typename std::enable_if_t<( \
+ has_member_next_##TYPE##_normal_sample<RNG>::value && \
+ has_member_set_next_##TYPE##_normal_sample<RNG>::value \
+ ), int> = 0> \
+C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \
+ if (generator->next_##TYPE##_normal_sample()) { \
+ *ret = *(generator->next_##TYPE##_normal_sample()); \
+ generator->set_next_##TYPE##_normal_sample(c10::optional<TYPE>()); \
+ return true; \
+ } \
+ return false; \
+} \
+ \
+template <typename RNG, typename ret_type, \
+ typename std::enable_if_t<( \
+ !has_member_next_##TYPE##_normal_sample<RNG>::value || \
+ !has_member_set_next_##TYPE##_normal_sample<RNG>::value \
+ ), int> = 0> \
+C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \
+ return false; \
+} \
+ \
+template <typename RNG, typename ret_type, \
+ typename std::enable_if_t<( \
+ has_member_set_next_##TYPE##_normal_sample<RNG>::value \
+ ), int> = 0> \
+C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
+ generator->set_next_##TYPE##_normal_sample(cache); \
+} \
+ \
+template <typename RNG, typename ret_type, \
+ typename std::enable_if_t<( \
+ !has_member_set_next_##TYPE##_normal_sample<RNG>::value \
+ ), int> = 0> \
+C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
+}
+
+DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double);
+DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float);
+
/**
* Samples a normal distribution using the Box-Muller method
* Takes mean and standard deviation as inputs
@@ -144,41 +205,29 @@
template <typename RNG>
C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
dist_acctype<T> ret;
-#if !defined(__CUDACC__) && !defined(__HIPCC__)
// return cached values if available
if (std::is_same<T, double>::value) {
- if (generator->next_double_normal_sample()) {
- ret = *(generator->next_double_normal_sample()) * stdv + mean;
- // reset c10::optional to null
- generator->set_next_double_normal_sample(c10::optional<double>());
- return ret;
+ if (maybe_get_next_double_normal_sample(generator, &ret)) {
+ return transformation::normal(ret, mean, stdv);
}
} else {
- if (generator->next_float_normal_sample()) {
- ret = *(generator->next_float_normal_sample()) * stdv + mean;
- // reset c10::optional to null
- generator->set_next_float_normal_sample(c10::optional<float>());
- return ret;
+ if (maybe_get_next_float_normal_sample(generator, &ret)) {
+ return transformation::normal(ret, mean, stdv);
}
}
-#endif
// otherwise generate new normal values
uniform_real_distribution<T> uniform(0.0, 1.0);
const dist_acctype<T> u1 = uniform(generator);
const dist_acctype<T> u2 = uniform(generator);
const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log(static_cast<T>(1.0)-u2));
const dist_acctype<T> theta = static_cast<T>(2.0) * static_cast<T>(M_PI) * u1;
-#if !defined(__CUDACC__) && !defined(__HIPCC__)
if (std::is_same<T, double>::value) {
- dist_acctype<double> cache = r * ::sin(theta);
- generator->set_next_double_normal_sample(c10::optional<double>(cache));
+ maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
} else {
- dist_acctype<float> cache = r * ::sin(theta);
- generator->set_next_float_normal_sample(c10::optional<float>(cache));
+ maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
}
-#endif
- ret = transformation::normal(r * ::cos(theta), mean, stdv);
- return ret;
+ ret = r * ::cos(theta);
+ return transformation::normal(ret, mean, stdv);
}
private: