Prerequisites for CSPRNG (#36631)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36631
Summary of changes
1. Factored the random transformation logic out of the CPU/CUDA kernels into distribution structs in DistributionsHelper.h (`uniform_int_from_to_distribution`, `uniform_int_full_range_distribution`, `uniform_int_distribution`) and transformation functions in the new TransformationHelper.h, to avoid code duplication between the default CPU and CUDA RNGs and custom RNG extensions (see the sketch after this list)
2. Made `GeneratorImpl` fields `protected` instead of `private`, so that derived generator implementations can access them
3. Introduced `TORCH_CHECK_IF_NOT_ON_CUDA`, which behaves like `TORCH_CHECK` in host code and compiles to a no-op in CUDA/ROCm device code
4. To test multiple RNG extensions, I had to move the ops registration into a `registerOps()` function, expose it to Python, and call it from `def setUp(self)`
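
The main payoff of change 1 is that a custom RNG extension can now instantiate the same distribution structs the built-in kernels use. A minimal host-side sketch, assuming only that the generator type (here a hypothetical `MyCSPRNG`, passed by pointer) exposes `random()`/`random64()` like the default CPU generator:

```cpp
#include <ATen/core/DistributionsHelper.h>
#include <cstdint>

// Hypothetical custom generator; a real CSPRNG would return
// cryptographically secure bits from these two methods.
struct MyCSPRNG {
  uint32_t random()   { return 0x12345678u; }
  uint64_t random64() { return 0x123456789abcdef0ull; }
};

int64_t sample_from_to(MyCSPRNG* gen, uint64_t range, int64_t base) {
  // Same struct the default CPU/CUDA kernels now use, so the
  // modulo-and-shift logic is written exactly once.
  at::uniform_int_from_to_distribution<int64_t> dist(range, base);
  return dist(gen);
}
```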
Test Plan: Imported from OSS
Differential Revision: D21229202
Pulled By: pbelevich
fbshipit-source-id: 6aa3280f2fc3324cf3e748388b5087e3a1e49f23
diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h
index edd930f..5607974 100644
--- a/aten/src/ATen/core/DistributionsHelper.h
+++ b/aten/src/ATen/core/DistributionsHelper.h
@@ -9,7 +9,9 @@
#endif
#include <ATen/core/Array.h>
+#include <ATen/core/TransformationHelper.h>
#include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>
#include <c10/util/Optional.h>
#include <c10/macros/Macros.h>
@@ -35,75 +37,92 @@
namespace at {
-// Using VectorType in Box-muller derived distributions to avoid
-// code duplication
+/**
+ * Samples a discrete uniform distribution in the range [base, base+range) of type T
+ */
template <typename T>
-struct VectorType { };
+struct uniform_int_from_to_distribution {
-#if defined(__CUDACC__) || defined(__HIPCC__)
-template <> struct VectorType<half> { using type = at::detail::Array<float, 2>; };
-#endif
-template <> struct VectorType<Half> { using type = at::detail::Array<float, 2>; };
-template <> struct VectorType<float> { using type = at::detail::Array<float, 2>; };
-template <> struct VectorType<double> { using type = at::detail::Array<double, 2>; };
+ C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) {
+ range_ = range;
+ base_ = base;
+ }
-template <typename T>
-using vect_type = typename VectorType<T>::type;
+ template <typename RNG>
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
+ if ((
+ std::is_same<T, int64_t>::value ||
+ std::is_same<T, double>::value ||
+ std::is_same<T, float>::value ||
+ std::is_same<T, at::BFloat16>::value) && range_ >= 1ULL << 32)
+ {
+ return uniform_int_from_to_transformation<T>(generator->random64(), range_, base_);
+ } else {
+ return uniform_int_from_to_transformation<T>(generator->random(), range_, base_);
+ }
+ }
-// Using DistAccumType in accumulate types for distributions.
-// Note: Ideally we'd be using ATen/AccumulateType.h but looks
-// like the there is some inconsistency in how accumulate types
-// are mapped currently, e.g. for the cpu side, float is mapped
-// to double.
-template <typename T>
-struct DistAccumType { };
-
-#if defined(__CUDACC__) || defined(__HIPCC__)
-template <> struct DistAccumType<half> { using type = float; };
-#endif
-template <> struct DistAccumType<Half> { using type = float; };
-template <> struct DistAccumType<float> { using type = float; };
-template <> struct DistAccumType<double> { using type = double; };
-
-template <typename T>
-using dist_acctype = typename DistAccumType<T>::type;
-
-// Constants for uniform distribution
-// doubles have 52 bits of mantissa (fractional part)
-constexpr uint64_t DOUBLE_MASK = (1ULL << 53) - 1;
-constexpr double DOUBLE_DIVISOR = 1.0 / (1ULL << 53);
-
-// floats have 23 bits of mantissa (fractional part)
-constexpr uint32_t FLOAT_MASK = (1 << 24) - 1;
-constexpr float FLOAT_DIVISOR = 1.0f / (1 << 24);
+ private:
+ uint64_t range_;
+ int64_t base_;
+};
/**
- * Samples a uniform distribution in the range [0,1) of type T
+ * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)]
+ */
+template <typename T>
+struct uniform_int_full_range_distribution {
+
+ template <typename RNG>
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
+ return uniform_int_full_range_transformation<T>(generator->random64());
+ }
+
+};
+
+/**
+ * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types
+ * and [0, 2^mantissa] for floating-point types.
+ */
+template <typename T>
+struct uniform_int_distribution {
+
+ template <typename RNG>
+ C10_HOST_DEVICE inline T operator()(RNG generator) {
+ if (std::is_same<T, double>::value || std::is_same<T, int64_t>::value) {
+ return uniform_int_transformation<T>(generator->random64());
+ } else {
+ return uniform_int_transformation<T>(generator->random());
+ }
+ }
+
+};
+
+/**
+ * Samples a uniform distribution in the range [from, to) of type T
*/
template <typename T>
struct uniform_real_distribution {
- inline uniform_real_distribution(T a_in, T b_in) {
- TORCH_CHECK(a_in <= b_in);
- TORCH_CHECK(b_in-a_in <= std::numeric_limits<T>::max());
- a = a_in;
- b = b_in;
+ C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) {
+ TORCH_CHECK_IF_NOT_ON_CUDA(from <= to);
+ TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
+ from_ = from;
+ to_ = to;
}
template <typename RNG>
- inline dist_acctype<T> operator()(RNG* generator){
- dist_acctype<T> x;
+ C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
if(std::is_same<T, double>::value) {
- x = (generator->random64() & DOUBLE_MASK) * DOUBLE_DIVISOR;
+ return uniform_real_transformation<T>(generator->random64(), from_, to_);
} else {
- x = (generator->random() & FLOAT_MASK) * FLOAT_DIVISOR;
+ return uniform_real_transformation<T>(generator->random(), from_, to_);
}
- return (x * (b - a) + a);
}
private:
- T a;
- T b;
+ T from_;
+ T to_;
};
/**
@@ -116,14 +135,15 @@
struct normal_distribution {
inline normal_distribution(T mean_in, T stdv_in) {
- TORCH_CHECK(stdv_in > 0);
+ TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0);
mean = mean_in;
stdv = stdv_in;
}
template <typename RNG>
- inline dist_acctype<T> operator()(RNG* generator){
+ inline dist_acctype<T> operator()(RNG generator){
dist_acctype<T> ret;
+#if !defined(__CUDACC__) && !defined(__HIPCC__)
// return cached values if available
if (std::is_same<T, double>::value) {
if (generator->next_double_normal_sample()) {
@@ -140,12 +160,14 @@
return ret;
}
}
+#endif
// otherwise generate new normal values
uniform_real_distribution<T> uniform(0.0, 1.0);
const dist_acctype<T> u1 = uniform(generator);
const dist_acctype<T> u2 = uniform(generator);
const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log(static_cast<T>(1.0)-u2));
const dist_acctype<T> theta = static_cast<T>(2.0) * static_cast<T>(M_PI) * u1;
+#if !defined(__CUDACC__) && !defined(__HIPCC__)
if (std::is_same<T, double>::value) {
dist_acctype<double> cache = r * ::sin(theta);
generator->set_next_double_normal_sample(c10::optional<double>(cache));
@@ -153,6 +175,7 @@
dist_acctype<float> cache = r * ::sin(theta);
generator->set_next_float_normal_sample(c10::optional<float>(cache));
}
+#endif
ret = r * ::cos(theta) * stdv + mean;
return ret;
}
@@ -174,7 +197,7 @@
}
template <typename RNG>
- inline int operator()(RNG* generator) {
+ inline int operator()(RNG generator) {
uniform_real_distribution<T> uniform(0.0, 1.0);
return uniform(generator) < p;
}
@@ -195,7 +218,7 @@
}
template <typename RNG>
- inline int operator()(RNG* generator) {
+ inline int operator()(RNG generator) {
uniform_real_distribution<T> uniform(0.0, 1.0);
dist_acctype<T> sample = uniform(generator);
return static_cast<int>(::log(static_cast<T>(1.0)-sample) / ::log(p)) + 1;
@@ -216,7 +239,7 @@
}
template <typename RNG>
- __ubsan_ignore_float_divide_by_zero__ inline T operator()(RNG* generator) {
+ __ubsan_ignore_float_divide_by_zero__ inline T operator()(RNG generator) {
uniform_real_distribution<T> uniform(0.0, 1.0);
dist_acctype<T> sample = uniform(generator);
return static_cast<T>(-1.0) / lambda * ::log(static_cast<T>(1.0)-sample);
@@ -238,7 +261,7 @@
}
template <typename RNG>
- inline T operator()(RNG* generator) {
+ inline T operator()(RNG generator) {
uniform_real_distribution<T> uniform(0.0, 1.0);
return median + sigma * ::tan(static_cast<T>(M_PI) * (uniform(generator)-static_cast<T>(0.5)));
}
@@ -263,7 +286,7 @@
}
template<typename RNG>
- inline T operator()(RNG* generator){
+ inline T operator()(RNG generator){
normal_distribution<T> normal(mean, stdv);
return ::exp(normal(generator));
}
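
For context, `normal_distribution` above keeps the Box-Muller transform unchanged; the new `#if` guards only compile out the second-sample caching in device code, where the generator has no cache. A standalone sketch of the transform itself (caching omitted; `box_muller` is an illustrative name, not part of the PR):

```cpp
#include <cmath>
#include <utility>

// Box-Muller: two uniform samples u1, u2 in [0, 1) yield two independent
// normal samples with the given mean and standard deviation.
std::pair<double, double> box_muller(double u1, double u2, double mean, double stdv) {
  const double pi    = std::acos(-1.0);
  const double r     = std::sqrt(-2.0 * std::log(1.0 - u2));
  const double theta = 2.0 * pi * u1;
  // The first value is returned right away; the second is what the CPU path
  // above caches for the next call (the #if guards skip the cache on device).
  return {r * std::cos(theta) * stdv + mean,
          r * std::sin(theta) * stdv + mean};
}
```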
diff --git a/aten/src/ATen/core/TransformationHelper.h b/aten/src/ATen/core/TransformationHelper.h
new file mode 100644
index 0000000..7a3dd2a
--- /dev/null
+++ b/aten/src/ATen/core/TransformationHelper.h
@@ -0,0 +1,75 @@
+#include <c10/macros/Macros.h>
+#include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>
+#include <limits>
+#include <cstdint>
+#include <cassert>
+
+namespace at {
+
+// Using DistAccumType in accumulate types for distributions.
+// Note: Ideally we'd be using ATen/AccumulateType.h but looks
+// like the there is some inconsistency in how accumulate types
+// are mapped currently, e.g. for the cpu side, float is mapped
+// to double.
+template <typename T>
+struct DistAccumType { };
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template <> struct DistAccumType<half> { using type = float; };
+#endif
+template <> struct DistAccumType<Half> { using type = float; };
+template <> struct DistAccumType<float> { using type = float; };
+template <> struct DistAccumType<double> { using type = double; };
+
+template <typename T>
+using dist_acctype = typename DistAccumType<T>::type;
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified.
+ * `range` is `to - from`
+ * `base` is `from`
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_from_to_transformation(V val, uint64_t range, int64_t base) {
+ return static_cast<T>(static_cast<int64_t>((val % range) + base));
+}
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and to=None
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_full_range_transformation(V val) {
+ return static_cast<T>(static_cast<int64_t>(val));
+}
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_transformation(V val) {
+ if (std::is_same<T, bool>::value) {
+ return static_cast<bool>(val & 1);
+ } else if (std::is_same<T, double>::value) {
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
+ } else if (std::is_same<T, int64_t>::value) {
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
+ } else if (std::is_floating_point<T>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) {
+ return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
+ } else if (std::is_integral<T>::value) {
+ return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
+ } else {
+ assert(false);
+ return 0;
+ }
+}
+
+template <typename T, typename V>
+C10_HOST_DEVICE inline dist_acctype<T> uniform_real_transformation(V val, T from, T to) {
+ constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
+ constexpr auto DIVISOR = static_cast<T>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
+ dist_acctype<T> x = (val & MASK) * DIVISOR;
+ return (x * (to - from) + from);
+}
+
+} // namespace at
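
`uniform_real_transformation` is the old mask/divisor scheme generalized over the mantissa width via `std::numeric_limits<T>::digits`. A self-contained illustration for `double` (53 significand bits; the helper name `to_unit_interval` is made up for this sketch):

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

// Mask a raw 64-bit sample down to 53 bits and divide by 2^53, which maps
// it into [0, 1); the affine step then rescales it into [from, to).
double to_unit_interval(uint64_t raw) {
  constexpr uint64_t MASK    = (uint64_t(1) << std::numeric_limits<double>::digits) - 1;
  constexpr double   DIVISOR = 1.0 / (uint64_t(1) << std::numeric_limits<double>::digits);
  return (raw & MASK) * DIVISOR;
}

int main() {
  const double x = to_unit_interval(0x123456789abcdef0ull);
  std::printf("%.17g\n", x * (5.0 - 2.0) + 2.0);  // a sample in [2, 5)
}
```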
diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h
index c78e06f..38b88a4 100644
--- a/aten/src/ATen/native/cpu/DistributionTemplates.h
+++ b/aten/src/ATen/native/cpu/DistributionTemplates.h
@@ -23,20 +23,10 @@
void random_from_to_kernel(TensorIterator& iter, uint64_t range, int64_t base, RNG generator) {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cpu", [&] {
std::lock_guard<std::mutex> lock(generator->mutex_);
- if ((
- std::is_same<scalar_t, int64_t>::value ||
- std::is_same<scalar_t, double>::value ||
- std::is_same<scalar_t, float>::value ||
- std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
- {
- cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
- return static_cast<scalar_t>(static_cast<int64_t>((generator->random64() % range) + base));
- });
- } else {
- cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
- return static_cast<scalar_t>(static_cast<int64_t>((generator->random() % range) + base));
- });
- }
+ cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
+ uniform_int_from_to_distribution<scalar_t> random(range, base);
+ return random(generator);
+ });
});
}
@@ -52,7 +42,8 @@
std::is_same<scalar_t, float>::value ||
std::is_same<scalar_t, at::BFloat16>::value) {
cpu_serial_kernel(iter, [generator]() -> scalar_t {
- return static_cast<scalar_t>(static_cast<int64_t>(generator->random64()));
+ uniform_int_full_range_distribution<scalar_t> random;
+ return random(generator);
});
} else {
TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
@@ -73,37 +64,12 @@
template<typename RNG>
void random_kernel(TensorIterator& iter, RNG generator) {
std::lock_guard<std::mutex> lock(generator->mutex_);
- if (isFloatingType(iter.dtype())) {
- AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_kernel_fp_cpu", [&] {
- if (std::is_same<scalar_t, double>::value) {
- cpu_serial_kernel(iter, [generator]() -> scalar_t {
- return static_cast<scalar_t>(generator->random64() % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
- });
- } else {
- cpu_serial_kernel(iter, [generator]() -> scalar_t {
- return static_cast<scalar_t>(generator->random() % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
- });
- }
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
+ cpu_serial_kernel(iter, [generator]() -> scalar_t {
+ uniform_int_distribution<scalar_t> random;
+ return random(generator);
});
- } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
- AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, iter.dtype(), "random_kernel_int_cpu", [&] {
- if (std::is_same<scalar_t, int64_t>::value) {
- cpu_serial_kernel(iter, [generator]() -> scalar_t {
- return static_cast<scalar_t>(generator->random64() % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
- });
- } else if (std::is_same<scalar_t, bool>::value) {
- cpu_serial_kernel(iter, [generator]() -> scalar_t {
- return static_cast<scalar_t>(generator->random() & 1);
- });
- } else {
- cpu_serial_kernel(iter, [generator]() -> scalar_t {
- return static_cast<scalar_t>(generator->random() % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
- });
- }
- });
- } else {
- TORCH_CHECK(false, "random_kernel_cpu handles only integral, floating-point and boolean types");
- }
+ });
}
template<typename RNG>
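
With the dispatch collapsed into a single `AT_DISPATCH_ALL_TYPES_AND3` call, the per-dtype behaviour of `random_()` now lives entirely in `uniform_int_transformation`. A standalone sketch of those rules for a few representative dtypes (taking a raw 64-bit sample for simplicity; the function names are illustrative):

```cpp
#include <cstdint>
#include <limits>

// bool keeps a single bit of the sample.
bool random_bool(uint64_t raw) {
  return raw & 1;
}

// Floating-point dtypes are reduced modulo 2^digits + 1, so the result is
// an integer value in [0, 2^digits] that the dtype can represent exactly.
double random_double(uint64_t raw) {
  return static_cast<double>(raw % ((uint64_t(1) << std::numeric_limits<double>::digits) + 1));
}

// Integral dtypes are reduced modulo max + 1, covering [0, max].
int32_t random_int32(uint64_t raw) {
  return static_cast<int32_t>(raw % (uint64_t(std::numeric_limits<int32_t>::max()) + 1));
}
```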
diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h
index 19b86de..4d768cb 100644
--- a/aten/src/ATen/native/cuda/DistributionTemplates.h
+++ b/aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -8,6 +8,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
+#include <ATen/core/DistributionsHelper.h>
#include <curand.h>
#include <curand_kernel.h>
@@ -286,7 +287,7 @@
{
// define lambda to mod with range and add base
auto random_func = [range, base] __device__ (uint64_t rand) {
- return static_cast<scalar_t>(static_cast<int64_t>(rand % range + base));
+ return uniform_int_from_to_transformation<scalar_t>(rand, range, base);
};
distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
gen,
@@ -300,7 +301,7 @@
random_func);
} else {
auto random_func = [range, base] __device__ (uint32_t rand) {
- return static_cast<scalar_t>(static_cast<int64_t>(rand % range + base));
+ return uniform_int_from_to_transformation<scalar_t>(rand, range, base);
};
distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
gen,
@@ -329,7 +330,7 @@
std::is_same<scalar_t, float>::value ||
std::is_same<scalar_t, at::BFloat16>::value) {
auto random_func = [] __device__ (uint64_t rand) {
- return static_cast<scalar_t>(static_cast<int64_t>(rand));
+ return uniform_int_full_range_transformation<scalar_t>(rand);
};
distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
gen,
@@ -365,75 +366,32 @@
TORCH_CHECK(false, "random_() is not supported for bfloat16 CUDA tensors on Windows. Please see https://github.com/pytorch/pytorch/issues/33793");
}
#endif
- if (isFloatingType(iter.dtype())) {
- AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_kernel_fp_cuda", [&] {
- if (std::is_same<scalar_t, double>::value) {
- auto random_func = [] __device__ (uint64_t rand) {
- return static_cast<scalar_t>(rand % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
- };
- distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
- gen,
- [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
- ulonglong2 ret;
- uint4 rand_val = curand4(state);
- ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
- ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
- return ret;
- },
- random_func);
- } else {
- auto random_func = [] __device__ (uint32_t rand) {
- return static_cast<scalar_t>(rand % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
- };
- distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
- gen,
- [] __device__ (curandStatePhilox4_32_10_t* state) {
- return curand4(state);
- },
- random_func);
- }
- });
- } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
- AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, iter.dtype(), "random_kernel_int_cuda", [&] {
- if (std::is_same<scalar_t, int64_t>::value) {
- auto random_func = [] __device__ (uint64_t rand) {
- return static_cast<scalar_t>(rand % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
- };
- distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
- gen,
- [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
- ulonglong2 ret;
- uint4 rand_val = curand4(state);
- ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
- ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
- return ret;
- },
- random_func);
- } else if (std::is_same<scalar_t, bool>::value) {
- auto random_func = [] __device__ (uint32_t rand) {
- return static_cast<scalar_t>(rand & 1);
- };
- distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
- gen,
- [] __device__ (curandStatePhilox4_32_10_t* state) {
- return curand4(state);
- },
- random_func);
- } else {
- auto random_func = [] __device__ (uint32_t rand) {
- return static_cast<scalar_t>(rand % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
- };
- distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
- gen,
- [] __device__ (curandStatePhilox4_32_10_t* state) {
- return curand4(state);
- },
- random_func);
- }
- });
- } else {
- TORCH_CHECK(false, "random_kernel_cuda handles only integral, floating-point and boolean types");
- }
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
+ if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
+ auto random_func = [] __device__ (uint64_t rand) {
+ return uniform_int_transformation<scalar_t>(rand);
+ };
+ distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter, gen,
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
+ ulonglong2 ret;
+ uint4 rand_val = curand4(state);
+ ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
+ ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
+ return ret;
+ },
+ random_func);
+ } else {
+ auto random_func = [] __device__ (uint32_t rand) {
+ return uniform_int_transformation<scalar_t>(rand);
+ };
+ distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
+ gen,
+ [] __device__ (curandStatePhilox4_32_10_t* state) {
+ return curand4(state);
+ },
+ random_func);
+ }
+ });
}
template<typename RNG>
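
The 64-bit CUDA paths keep the existing trick of packing two 32-bit curand outputs into one 64-bit sample and only swap the final step for the shared transformation helpers. A host-side sketch of that packing step (the `pack64` name is made up here):

```cpp
#include <cstdint>

// Build one 64-bit sample out of two independent 32-bit samples, as the
// ulonglong2-producing lambdas in the CUDA kernels above do.
uint64_t pack64(uint32_t hi, uint32_t lo) {
  return (static_cast<uint64_t>(hi) << 32) | lo;
}
```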
diff --git a/c10/core/GeneratorImpl.h b/c10/core/GeneratorImpl.h
index 9889625..fff105a 100644
--- a/c10/core/GeneratorImpl.h
+++ b/c10/core/GeneratorImpl.h
@@ -86,7 +86,7 @@
return pyobj_;
}
- private:
+ protected:
Device device_;
DispatchKeySet key_set_;
PyObject* pyobj_ = nullptr;
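
Making these fields `protected` lets a derived generator, such as the custom CSPRNG generator this PR prepares for, touch them directly. A simplified, self-contained analogue of the change; this is deliberately not the real `c10::GeneratorImpl` interface:

```cpp
#include <cstdint>
#include <cstdio>

struct GeneratorBase {
 protected:
  int device_index_ = 0;   // stands in for Device device_;
  uint64_t seed_ = 0;      // stands in for generator state
};

struct CustomGenerator : GeneratorBase {
  void set_current_seed(uint64_t seed) {
    seed_ = seed;          // legal only because the field is protected, not private
  }
  void debug_print() const {
    std::printf("device %d, seed %llu\n", device_index_,
                static_cast<unsigned long long>(seed_));
  }
};
```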
diff --git a/c10/util/Exception.h b/c10/util/Exception.h
index 39adf92..54086b7 100644
--- a/c10/util/Exception.h
+++ b/c10/util/Exception.h
@@ -272,6 +272,15 @@
#endif
#define TORCH_CHECK(cond, ...) TORCH_CHECK_WITH(Error, cond, __VA_ARGS__)
+// A utility macro that does what `TORCH_CHECK` does if compiled in the host code,
+// otherwise does nothing. Supposed to be used in the code shared between host and
+// device code as an alternative for `TORCH_CHECK`.
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...)
+#else
+#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) TORCH_CHECK(cond, __VA_ARGS__)
+#endif
+
// Debug only version of TORCH_INTERNAL_ASSERT. This macro only checks in debug
// build, and does nothing in release build. It is appropriate to use
// in situations where you want to add an assert to a hotpath, but it is
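
The intended use for `TORCH_CHECK_IF_NOT_ON_CUDA` is constructors and other code paths compiled for both host and device, as in `uniform_real_distribution` above. A minimal sketch under that assumption (`clamped_range` is a hypothetical type, not from the PR):

```cpp
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

// In host compilation the argument check behaves like TORCH_CHECK; when nvcc
// or hipcc compiles the same header for device code, the check is compiled out.
template <typename T>
struct clamped_range {
  C10_HOST_DEVICE clamped_range(T lo, T hi) {
    TORCH_CHECK_IF_NOT_ON_CUDA(lo <= hi, "expected lo <= hi");
    lo_ = lo;
    hi_ = hi;
  }
  T lo_, hi_;
};
```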
diff --git a/test/cpp_extensions/rng_extension.cpp b/test/cpp_extensions/rng_extension.cpp
index ba74173..c16e35e 100644
--- a/test/cpp_extensions/rng_extension.cpp
+++ b/test/cpp_extensions/rng_extension.cpp
@@ -53,7 +53,8 @@
return instance_count;
}
-static auto registry = torch::RegisterOperators()
+void registerOps() {
+ static auto registry = torch::RegisterOperators()
.op(torch::RegisterOperators::options()
.schema("aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)")
.impl_unboxedOnlyKernel<decltype(random_from_to), &random_from_to>(DispatchKey::CustomRNGKeyId))
@@ -63,8 +64,10 @@
.op(torch::RegisterOperators::options()
.schema("aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)")
.impl_unboxedOnlyKernel<decltype(random_), &random_>(DispatchKey::CustomRNGKeyId));
+}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("registerOps", ®isterOps);
m.def("createTestCPUGenerator", &createTestCPUGenerator);
m.def("getInstanceCount", &getInstanceCount);
m.def("identity", &identity);
diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py
index 0cc5e6f..c860e95 100644
--- a/test/test_cpp_extensions_aot.py
+++ b/test/test_cpp_extensions_aot.py
@@ -140,6 +140,10 @@
class TestRNGExtension(common.TestCase):
+ def setUp(self):
+ super(TestRNGExtension, self).setUp()
+ rng_extension.registerOps()
+
def test_rng(self):
fourty_two = torch.full((10,), 42, dtype=torch.int64)