Prerequisites for CSPRNG (#36631)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36631

Summary of changes

1. Moved the random distribution structs (`uniform_int_from_to_distribution`, `uniform_int_full_range_distribution`, `uniform_int_distribution`) to DistributionsHelper.h and factored their underlying transformation functions into the new TransformationHelper.h, to avoid code duplication between the default CPU/CUDA RNGs and custom RNG extensions (see the first sketch after this list)
2. Made `GeneratorImpl` fields `protected` instead of `private`
3. Introduced `TORCH_CHECK_IF_NOT_ON_CUDA`, which behaves like `TORCH_CHECK` in host-only code and expands to nothing when compiled for a CUDA/ROCm device (see the second sketch after this list)
4. To test multiple RNG extensions, moved the ops registration into a `registerOps()` function, exposed it to Python, and called it from the test's `setUp(self)`
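
For item 1, a minimal sketch (not part of this diff) of how a custom RNG extension can now reuse the shared distribution structs instead of duplicating the 32-bit vs 64-bit branching. `sample_from_to` is a hypothetical helper; the only assumption is a generator type that exposes `random()`/`random64()`, like the `TestCPUGenerator` in test/cpp_extensions/rng_extension.cpp:

```cpp
#include <cstdint>
#include <vector>
#include <ATen/core/DistributionsHelper.h>

// Fills a buffer with samples from [base, base + range) using a custom generator.
// The shared struct decides internally whether to draw random() or random64(),
// so the extension no longer re-implements that branch.
template <typename T, typename RNG>
std::vector<T> sample_from_to(RNG* generator, uint64_t range, int64_t base, size_t n) {
  at::uniform_int_from_to_distribution<T> random(range, base);
  std::vector<T> out(n);
  for (auto& v : out) {
    v = random(generator);
  }
  return out;
}
```
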
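For item 3, a minimal sketch (again illustrative, not part of this diff) of how `TORCH_CHECK_IF_NOT_ON_CUDA` is meant to be used in code shared between host and device; `my_uniform_real_distribution` is a hypothetical struct patterned on the `uniform_real_distribution` change below:

```cpp
#include <limits>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

template <typename T>
struct my_uniform_real_distribution {
  // When compiled by nvcc/hipcc these checks expand to nothing, so the
  // constructor stays usable from device code; in a host-only build they
  // behave exactly like TORCH_CHECK.
  C10_HOST_DEVICE my_uniform_real_distribution(T from, T to) {
    TORCH_CHECK_IF_NOT_ON_CUDA(from <= to, "expected from <= to");
    TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
    from_ = from;
    to_ = to;
  }
  T from_;
  T to_;
};
```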

Test Plan: Imported from OSS

Differential Revision: D21229202

Pulled By: pbelevich

fbshipit-source-id: 6aa3280f2fc3324cf3e748388b5087e3a1e49f23
diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h
index edd930f..5607974 100644
--- a/aten/src/ATen/core/DistributionsHelper.h
+++ b/aten/src/ATen/core/DistributionsHelper.h
@@ -9,7 +9,9 @@
 #endif
 
 #include <ATen/core/Array.h>
+#include <ATen/core/TransformationHelper.h>
 #include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>
 #include <c10/util/Optional.h>
 #include <c10/macros/Macros.h>
 
@@ -35,75 +37,92 @@
 
 namespace at {
 
-// Using VectorType in Box-muller derived distributions to avoid
-// code duplication
+/**
+ * Samples a discrete uniform distribution in the range [base, base+range) of type T
+ */
 template <typename T>
-struct VectorType {  };
+struct uniform_int_from_to_distribution {
 
-#if defined(__CUDACC__) || defined(__HIPCC__)
-template <> struct VectorType<half> { using type = at::detail::Array<float, 2>; };
-#endif
-template <> struct VectorType<Half> { using type = at::detail::Array<float, 2>; };
-template <> struct VectorType<float> { using type = at::detail::Array<float, 2>; };
-template <> struct VectorType<double> { using type = at::detail::Array<double, 2>; };
+  C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) {
+    range_ = range;
+    base_ = base;
+  }
 
-template <typename T>
-using vect_type = typename VectorType<T>::type;
+  template <typename RNG>
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    if ((
+      std::is_same<T, int64_t>::value ||
+      std::is_same<T, double>::value ||
+      std::is_same<T, float>::value ||
+      std::is_same<T, at::BFloat16>::value) && range_ >= 1ULL << 32)
+    {
+      return uniform_int_from_to_transformation<T>(generator->random64(), range_, base_);
+    } else {
+      return uniform_int_from_to_transformation<T>(generator->random(), range_, base_);
+    }
+  }
 
-// Using DistAccumType in accumulate types for distributions.
-// Note: Ideally we'd be using ATen/AccumulateType.h but looks
-// like the there is some inconsistency in how accumulate types
-// are mapped currently, e.g. for the cpu side, float is mapped
-// to double.
-template <typename T>
-struct DistAccumType {  };
-
-#if defined(__CUDACC__) || defined(__HIPCC__)
-template <> struct DistAccumType<half> { using type = float; };
-#endif
-template <> struct DistAccumType<Half> { using type = float; };
-template <> struct DistAccumType<float> { using type = float; };
-template <> struct DistAccumType<double> { using type = double; };
-
-template <typename T>
-using dist_acctype = typename DistAccumType<T>::type;
-
-// Constants for uniform distribution
-// doubles have 52 bits of mantissa (fractional part)
-constexpr uint64_t DOUBLE_MASK = (1ULL << 53) - 1;
-constexpr double DOUBLE_DIVISOR = 1.0 / (1ULL << 53);
-
-// floats have 23 bits of mantissa (fractional part)
-constexpr uint32_t FLOAT_MASK = (1 << 24) - 1;
-constexpr float FLOAT_DIVISOR = 1.0f / (1 << 24);
+  private:
+    uint64_t range_;
+    int64_t base_;
+};
 
 /**
- * Samples a uniform distribution in the range [0,1) of type T
+ * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)]
+ */
+template <typename T>
+struct uniform_int_full_range_distribution {
+
+  template <typename RNG>
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    return uniform_int_full_range_transformation<T>(generator->random64());
+  }
+
+};
+
+/**
+ * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types
+ * and [0, 2^mantissa] for floating-point types.
+ */
+template <typename T>
+struct uniform_int_distribution {
+
+  template <typename RNG>
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    if (std::is_same<T, double>::value || std::is_same<T, int64_t>::value) {
+      return uniform_int_transformation<T>(generator->random64());
+    } else {
+      return uniform_int_transformation<T>(generator->random());
+    }
+  }
+
+};
+
+/**
+ * Samples a uniform distribution in the range [from, to) of type T
  */
 template <typename T>
 struct uniform_real_distribution {
 
-  inline uniform_real_distribution(T a_in, T b_in) {
-    TORCH_CHECK(a_in <= b_in);
-    TORCH_CHECK(b_in-a_in <= std::numeric_limits<T>::max());
-    a = a_in;
-    b = b_in;
+  C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) {
+    TORCH_CHECK_IF_NOT_ON_CUDA(from <= to);
+    TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
+    from_ = from;
+    to_ = to;
   }
 
   template <typename RNG>
-  inline dist_acctype<T> operator()(RNG* generator){
-    dist_acctype<T> x;
+  C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
     if(std::is_same<T, double>::value) {
-      x = (generator->random64() & DOUBLE_MASK) * DOUBLE_DIVISOR;
+      return uniform_real_transformation<T>(generator->random64(), from_, to_);
     } else {
-      x = (generator->random() & FLOAT_MASK) * FLOAT_DIVISOR;
+      return uniform_real_transformation<T>(generator->random(), from_, to_);
     }
-    return (x * (b - a) + a);
   }
 
   private:
-    T a;
-    T b;
+    T from_;
+    T to_;
 };
 
 /**
@@ -116,14 +135,15 @@
 struct normal_distribution {
 
   inline normal_distribution(T mean_in, T stdv_in) {
-    TORCH_CHECK(stdv_in > 0);
+    TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0);
     mean = mean_in;
     stdv = stdv_in;
   }
 
   template <typename RNG>
-  inline dist_acctype<T> operator()(RNG* generator){
+  inline dist_acctype<T> operator()(RNG generator){
     dist_acctype<T> ret;
+#if !defined(__CUDACC__) && !defined(__HIPCC__)
     // return cached values if available
     if (std::is_same<T, double>::value) {
       if (generator->next_double_normal_sample()) {
@@ -140,12 +160,14 @@
         return ret;
       }
     }
+#endif
     // otherwise generate new normal values
     uniform_real_distribution<T> uniform(0.0, 1.0);
     const dist_acctype<T> u1 = uniform(generator);
     const dist_acctype<T> u2 = uniform(generator);
     const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log(static_cast<T>(1.0)-u2));
     const dist_acctype<T> theta = static_cast<T>(2.0) * static_cast<T>(M_PI) * u1;
+#if !defined(__CUDACC__) && !defined(__HIPCC__)
     if (std::is_same<T, double>::value) {
       dist_acctype<double> cache = r * ::sin(theta);
       generator->set_next_double_normal_sample(c10::optional<double>(cache));
@@ -153,6 +175,7 @@
       dist_acctype<float> cache = r * ::sin(theta);
       generator->set_next_float_normal_sample(c10::optional<float>(cache));
     }
+#endif
     ret = r * ::cos(theta) * stdv + mean;
     return ret;
   }
@@ -174,7 +197,7 @@
   }
 
   template <typename RNG>
-  inline int operator()(RNG* generator) {
+  inline int operator()(RNG generator) {
     uniform_real_distribution<T> uniform(0.0, 1.0);
     return uniform(generator) < p;
   }
@@ -195,7 +218,7 @@
   }
 
   template <typename RNG>
-  inline int operator()(RNG* generator) {
+  inline int operator()(RNG generator) {
     uniform_real_distribution<T> uniform(0.0, 1.0);
     dist_acctype<T> sample = uniform(generator);
     return static_cast<int>(::log(static_cast<T>(1.0)-sample) / ::log(p)) + 1;
@@ -216,7 +239,7 @@
   }
 
   template <typename RNG>
-  __ubsan_ignore_float_divide_by_zero__ inline T operator()(RNG* generator) {
+  __ubsan_ignore_float_divide_by_zero__ inline T operator()(RNG generator) {
     uniform_real_distribution<T> uniform(0.0, 1.0);
     dist_acctype<T> sample = uniform(generator);
     return static_cast<T>(-1.0) / lambda * ::log(static_cast<T>(1.0)-sample);
@@ -238,7 +261,7 @@
   }
 
   template <typename RNG>
-  inline T operator()(RNG* generator) {
+  inline T operator()(RNG generator) {
     uniform_real_distribution<T> uniform(0.0, 1.0);
     return median + sigma * ::tan(static_cast<T>(M_PI) * (uniform(generator)-static_cast<T>(0.5)));
   }
@@ -263,7 +286,7 @@
   }
 
   template<typename RNG>
-  inline T operator()(RNG* generator){
+  inline T operator()(RNG generator){
     normal_distribution<T> normal(mean, stdv);
     return ::exp(normal(generator));
   }
diff --git a/aten/src/ATen/core/TransformationHelper.h b/aten/src/ATen/core/TransformationHelper.h
new file mode 100644
index 0000000..7a3dd2a
--- /dev/null
+++ b/aten/src/ATen/core/TransformationHelper.h
@@ -0,0 +1,75 @@
+#include <c10/macros/Macros.h>
+#include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>
+#include <limits>
+#include <cstdint>
+#include <cassert>
+
+namespace at {
+
+// Using DistAccumType as the accumulate type for distributions.
+// Note: Ideally we'd be using ATen/AccumulateType.h but it looks
+// like there is some inconsistency in how accumulate types
+// are mapped currently, e.g. for the cpu side, float is mapped
+// to double.
+template <typename T>
+struct DistAccumType {  };
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template <> struct DistAccumType<half> { using type = float; };
+#endif
+template <> struct DistAccumType<Half> { using type = float; };
+template <> struct DistAccumType<float> { using type = float; };
+template <> struct DistAccumType<double> { using type = double; };
+
+template <typename T>
+using dist_acctype = typename DistAccumType<T>::type;
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified.
+ * `range` is `to - from`
+ * `base` is `from`
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_from_to_transformation(V val, uint64_t range, int64_t base) {
+  return static_cast<T>(static_cast<int64_t>((val % range) + base));
+}
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and `to=None`
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_full_range_transformation(V val) {
+  return static_cast<T>(static_cast<int64_t>(val));
+}
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_transformation(V val) {
+  if (std::is_same<T, bool>::value) {
+    return static_cast<bool>(val & 1);
+  } else if (std::is_same<T, double>::value) {
+    return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
+  } else if (std::is_same<T, int64_t>::value) {
+    return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
+  } else if (std::is_floating_point<T>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) {
+    return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
+  } else if (std::is_integral<T>::value) {
+    return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
+  } else {
+    assert(false);
+    return 0;
+  }
+}
+
+template <typename T, typename V>
+C10_HOST_DEVICE inline dist_acctype<T> uniform_real_transformation(V val, T from, T to) {
+  constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
+  constexpr auto DIVISOR = static_cast<T>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
+  dist_acctype<T> x = (val & MASK) * DIVISOR;
+  return (x * (to - from) + from);
+}
+
+} // namespace at
diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h
index c78e06f..38b88a4 100644
--- a/aten/src/ATen/native/cpu/DistributionTemplates.h
+++ b/aten/src/ATen/native/cpu/DistributionTemplates.h
@@ -23,20 +23,10 @@
 void random_from_to_kernel(TensorIterator& iter, uint64_t range, int64_t base, RNG generator) {
   AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cpu", [&] {
     std::lock_guard<std::mutex> lock(generator->mutex_);
-    if ((
-      std::is_same<scalar_t, int64_t>::value ||
-      std::is_same<scalar_t, double>::value ||
-      std::is_same<scalar_t, float>::value ||
-      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
-    {
-      cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
-        return static_cast<scalar_t>(static_cast<int64_t>((generator->random64() % range) + base));
-      });
-    } else {
-      cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
-        return static_cast<scalar_t>(static_cast<int64_t>((generator->random() % range) + base));
-      });
-    }
+    cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
+      uniform_int_from_to_distribution<scalar_t> random(range, base);
+      return random(generator);
+    });
   });
 }
 
@@ -52,7 +42,8 @@
         std::is_same<scalar_t, float>::value ||
         std::is_same<scalar_t, at::BFloat16>::value) {
       cpu_serial_kernel(iter, [generator]() -> scalar_t {
-        return static_cast<scalar_t>(static_cast<int64_t>(generator->random64()));
+        uniform_int_full_range_distribution<scalar_t> random;
+        return random(generator);
       });
     } else {
       TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
@@ -73,37 +64,12 @@
 template<typename RNG>
 void random_kernel(TensorIterator& iter, RNG generator) {
   std::lock_guard<std::mutex> lock(generator->mutex_);
-  if (isFloatingType(iter.dtype())) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_kernel_fp_cpu", [&] {
-      if (std::is_same<scalar_t, double>::value) {
-        cpu_serial_kernel(iter, [generator]() -> scalar_t {
-          return static_cast<scalar_t>(generator->random64() % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
-        });
-      } else {
-        cpu_serial_kernel(iter, [generator]() -> scalar_t {
-          return static_cast<scalar_t>(generator->random() % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
-        });
-      }
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
+    cpu_serial_kernel(iter, [generator]() -> scalar_t {
+      uniform_int_distribution<scalar_t> random;
+      return random(generator);
     });
-  } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
-    AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, iter.dtype(), "random_kernel_int_cpu", [&] {
-      if (std::is_same<scalar_t, int64_t>::value) {
-        cpu_serial_kernel(iter, [generator]() -> scalar_t {
-          return static_cast<scalar_t>(generator->random64() % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
-        });
-      } else if (std::is_same<scalar_t, bool>::value) {
-        cpu_serial_kernel(iter, [generator]() -> scalar_t {
-          return static_cast<scalar_t>(generator->random() & 1);
-        });
-      } else {
-        cpu_serial_kernel(iter, [generator]() -> scalar_t {
-          return static_cast<scalar_t>(generator->random() % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
-        });
-      }
-    });
-  } else {
-    TORCH_CHECK(false, "random_kernel_cpu handles only integral, floating-point and boolean types");
-  }
+  });
 }
 
 template<typename RNG>
diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h
index 19b86de..4d768cb 100644
--- a/aten/src/ATen/native/cuda/DistributionTemplates.h
+++ b/aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -8,6 +8,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/detail/FunctionTraits.h>
+#include <ATen/core/DistributionsHelper.h>
 
 #include <curand.h>
 #include <curand_kernel.h>
@@ -286,7 +287,7 @@
     {
       // define lambda to mod with range and add base
       auto random_func = [range, base] __device__ (uint64_t rand) {
-        return static_cast<scalar_t>(static_cast<int64_t>(rand % range + base));
+        return uniform_int_from_to_transformation<scalar_t>(rand, range, base);
       };
       distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
         gen,
@@ -300,7 +301,7 @@
         random_func);
     } else {
       auto random_func = [range, base] __device__ (uint32_t rand) {
-        return static_cast<scalar_t>(static_cast<int64_t>(rand % range + base));
+        return uniform_int_from_to_transformation<scalar_t>(rand, range, base);
       };
       distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
         gen,
@@ -329,7 +330,7 @@
         std::is_same<scalar_t, float>::value ||
         std::is_same<scalar_t, at::BFloat16>::value) {
       auto random_func = [] __device__ (uint64_t rand) {
-        return static_cast<scalar_t>(static_cast<int64_t>(rand));
+        return uniform_int_full_range_transformation<scalar_t>(rand);
       };
       distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
         gen,
@@ -365,75 +366,32 @@
     TORCH_CHECK(false, "random_() is not supported for bfloat16 CUDA tensors on Windows. Please see https://github.com/pytorch/pytorch/issues/33793");
   }
 #endif
-  if (isFloatingType(iter.dtype())) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_kernel_fp_cuda", [&] {
-      if (std::is_same<scalar_t, double>::value) {
-        auto random_func = [] __device__ (uint64_t rand) {
-          return static_cast<scalar_t>(rand % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
-        };
-        distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
-          gen,
-          [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
-            ulonglong2 ret;
-            uint4 rand_val = curand4(state);
-            ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
-            ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
-            return ret;
-          },
-          random_func);
-      } else {
-        auto random_func = [] __device__ (uint32_t rand) {
-          return static_cast<scalar_t>(rand % static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1));
-        };
-        distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
-          gen,
-          [] __device__ (curandStatePhilox4_32_10_t* state) {
-            return curand4(state);
-          },
-          random_func);
-      }
-    });
-  } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
-    AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, iter.dtype(), "random_kernel_int_cuda", [&] {
-      if (std::is_same<scalar_t, int64_t>::value) {
-        auto random_func = [] __device__ (uint64_t rand) {
-          return static_cast<scalar_t>(rand % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
-        };
-        distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter,
-          gen,
-          [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
-            ulonglong2 ret;
-            uint4 rand_val = curand4(state);
-            ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
-            ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
-            return ret;
-          },
-          random_func);
-      } else if (std::is_same<scalar_t, bool>::value) {
-        auto random_func = [] __device__ (uint32_t rand) {
-          return static_cast<scalar_t>(rand & 1);
-        };
-        distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
-          gen,
-          [] __device__ (curandStatePhilox4_32_10_t* state) {
-            return curand4(state);
-          },
-          random_func);
-      } else {
-        auto random_func = [] __device__ (uint32_t rand) {
-          return static_cast<scalar_t>(rand % (static_cast<uint64_t>(std::numeric_limits<scalar_t>::max()) + 1));
-        };
-        distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
-          gen,
-          [] __device__ (curandStatePhilox4_32_10_t* state) {
-            return curand4(state);
-          },
-          random_func);
-      }
-    });
-  } else {
-    TORCH_CHECK(false, "random_kernel_cuda handles only integral, floating-point and boolean types");
-  }
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
+    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
+      auto random_func = [] __device__ (uint64_t rand) {
+        return uniform_int_transformation<scalar_t>(rand);
+      };
+      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(iter, gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
+          ulonglong2 ret;
+          uint4 rand_val = curand4(state);
+          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
+          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
+          return ret;
+        },
+        random_func);
+    } else {
+      auto random_func = [] __device__ (uint32_t rand) {
+        return uniform_int_transformation<scalar_t>(rand);
+      };
+      distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) {
+          return curand4(state);
+        },
+        random_func);
+    }
+  });
 }
 
 template<typename RNG>
diff --git a/c10/core/GeneratorImpl.h b/c10/core/GeneratorImpl.h
index 9889625..fff105a 100644
--- a/c10/core/GeneratorImpl.h
+++ b/c10/core/GeneratorImpl.h
@@ -86,7 +86,7 @@
     return pyobj_;
   }
 
-  private:
+  protected:
     Device device_;
     DispatchKeySet key_set_;
     PyObject* pyobj_ = nullptr;
diff --git a/c10/util/Exception.h b/c10/util/Exception.h
index 39adf92..54086b7 100644
--- a/c10/util/Exception.h
+++ b/c10/util/Exception.h
@@ -272,6 +272,15 @@
 #endif
 #define TORCH_CHECK(cond, ...) TORCH_CHECK_WITH(Error, cond, __VA_ARGS__)
 
+// A utility macro that does what `TORCH_CHECK` does if compiled in host code,
+// otherwise does nothing. Intended for code that is shared between host and
+// device as an alternative to `TORCH_CHECK`.
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...)
+#else
+#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) TORCH_CHECK(cond, __VA_ARGS__)
+#endif
+
 // Debug only version of TORCH_INTERNAL_ASSERT. This macro only checks in debug
 // build, and does nothing in release build.  It is appropriate to use
 // in situations where you want to add an assert to a hotpath, but it is
diff --git a/test/cpp_extensions/rng_extension.cpp b/test/cpp_extensions/rng_extension.cpp
index ba74173..c16e35e 100644
--- a/test/cpp_extensions/rng_extension.cpp
+++ b/test/cpp_extensions/rng_extension.cpp
@@ -53,7 +53,8 @@
   return instance_count;
 }
 
-static auto registry = torch::RegisterOperators()
+void registerOps() {
+  static auto registry = torch::RegisterOperators()
       .op(torch::RegisterOperators::options()
         .schema("aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)")
         .impl_unboxedOnlyKernel<decltype(random_from_to), &random_from_to>(DispatchKey::CustomRNGKeyId))
@@ -63,8 +64,10 @@
       .op(torch::RegisterOperators::options()
         .schema("aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)")
         .impl_unboxedOnlyKernel<decltype(random_), &random_>(DispatchKey::CustomRNGKeyId));
+}
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("registerOps", &registerOps);
   m.def("createTestCPUGenerator", &createTestCPUGenerator);
   m.def("getInstanceCount", &getInstanceCount);
   m.def("identity", &identity);
diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py
index 0cc5e6f..c860e95 100644
--- a/test/test_cpp_extensions_aot.py
+++ b/test/test_cpp_extensions_aot.py
@@ -140,6 +140,10 @@
 
 class TestRNGExtension(common.TestCase):
 
+    def setUp(self):
+        super(TestRNGExtension, self).setUp()
+        rng_extension.registerOps()
+
     def test_rng(self):
         fourty_two = torch.full((10,), 42, dtype=torch.int64)