[pytorch-vulkan] add aten::randn_like & aten::normal_ (#109075)

Summary:
Implemented `aten::normal_` shader and used it to create `aten::randn_like`.

Ops definitions:
https://pytorch.org/docs/stable/generated/torch.randn_like.html
https://pytorch.org/docs/stable/generated/torch.Tensor.normal_.html

Test Plan:
```
[ttingchulin@53491.od /data/sandcastle/boxes/fbsource (randn)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin  -- --gtest_filter="*<test>*" eg.  -- --gtest_filter="*randn_like*"

[==========] Running 2 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 2 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.randn_like
[       OK ] VulkanAPITest.randn_like (230 ms)
[ RUN      ] VulkanAPITest.randn_like_large
[       OK ] VulkanAPITest.randn_like_large (570 ms)
[----------] 2 tests from VulkanAPITest (801 ms total)

[----------] Global test environment tear-down
[==========] 2 tests from 1 test suite ran. (801 ms total)
[  PASSED  ] 2 tests.

[ttingchulin@53491.od /data/sandcastle/boxes/fbsource (randn)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin  -- --gtest_filter="*<test>*" eg.  -- --gtest_filter="*normal_*"
[==========] Running 3 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 3 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.normal_
[       OK ] VulkanAPITest.normal_ (222 ms)
[ RUN      ] VulkanAPITest.normal_large
[       OK ] VulkanAPITest.normal_large (136 ms)
[ RUN      ] VulkanAPITest.normal_error
[       OK ] VulkanAPITest.normal_error (37 ms)
[----------] 3 tests from VulkanAPITest (396 ms total)

[----------] Global test environment tear-down
[==========] 3 tests f.
```

Reviewed By: yipjustin

Differential Revision: D48814024

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109075
Approved by: https://github.com/yipjustin
diff --git a/aten/src/ATen/native/vulkan/glsl/normal_.glsl b/aten/src/ATen/native/vulkan/glsl/normal_.glsl
new file mode 100644
index 0000000..b93ce4a
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/glsl/normal_.glsl
@@ -0,0 +1,31 @@
+#version 450 core
+#define PRECISION $precision
+#define FORMAT $format
+
+#include "random.h"
+
+layout(std430) buffer;
+
+/* Qualifiers: layout - storage - precision - memory */
+
+layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION restrict Block {
+  ivec3 size;
+  float mean;
+  float std;
+} uBlock;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() { // fills one output texel with four normally distributed samples
+  ivec3 pos = ivec3(gl_GlobalInvocationID);
+  // Invocations whose position lies outside uBlock.size store nothing.
+  if (all(lessThan(pos, uBlock.size))) {
+    vec4 v = vec4( // each channel passes a different 4th seed component
+        get_gaussrand(ivec4(pos, -20), uBlock.mean, uBlock.std),
+        get_gaussrand(ivec4(pos, 40), uBlock.mean, uBlock.std),
+        get_gaussrand(ivec4(pos, -30), uBlock.mean, uBlock.std),
+        get_gaussrand(ivec4(pos, 15), uBlock.mean, uBlock.std));
+    imageStore(uOutput, pos, v);
+  }
+}
diff --git a/aten/src/ATen/native/vulkan/glsl/random.h b/aten/src/ATen/native/vulkan/glsl/random.h
index eb9ef9e..14650e3 100644
--- a/aten/src/ATen/native/vulkan/glsl/random.h
+++ b/aten/src/ATen/native/vulkan/glsl/random.h
@@ -2,6 +2,9 @@
  * Random utility functions
  */
 
+// Value of pi; get_gaussrand() uses it in the Box-Muller transform.
+#define PI 3.14159265358979323846264
+
 uint pcg_hash(uint v) {
   // From: https://www.reedbeta.com/blog/hash-functions-for-gpu-rendering/
   uint state = v * 747796405u + 2891336453u;
@@ -15,7 +18,36 @@
   return fract(s / 1234567.0);
 }
 
+float rand2_nonzero(ivec4 pos) { // rand2(pos), re-drawn until it is not 0.0
+  float v = rand2(pos);
+  int offset = 0;
+  while (v == 0.0) { // get_gaussrand() takes log() of such a value
+    offset++;
+    v = rand2(ivec4(pos.x + offset, pos.y, pos.z, pos.w)); // shift x, retry
+  }
+  return v;
+}
+
 float get_uniform(ivec4 pos, float from, float to) {
   float v = rand2(pos);
   return from + v * (to - from);
 }
+
+float get_gaussrand(ivec4 pos, float mean, float std) {
+  // Box-Muller transform, following the pseudocode on Wikipedia: it converts
+  // two uniformly sampled random numbers into two normally distributed ones.
+  // Since the shader can only use one value per position, a third uniformly
+  // sampled number (u3) acts as a coin flip to decide which of the two to keep.
+  // https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
+  float u1 = rand2_nonzero(pos);
+  float u2 = rand2_nonzero(ivec4(pos.x+10, pos.y+20, pos.z+30, pos.w+40));
+  float u3 = rand2_nonzero(ivec4(pos.x-10, pos.y-20, pos.z-30, pos.w-40));
+  // u1 is never 0.0 (see rand2_nonzero), so log(u1) below stays finite.
+  float mag = std * sqrt(-2.0 * log(u1));
+  float v;
+  if (u3 > 0.5)
+    v = mag * cos(2.0 * PI * u2) + mean;
+  else
+    v = mag * sin(2.0 * PI * u2) + mean;
+  return v;
+}
diff --git a/aten/src/ATen/native/vulkan/ops/Random.cpp b/aten/src/ATen/native/vulkan/ops/Random.cpp
index a0a1653..954785e 100644
--- a/aten/src/ATen/native/vulkan/ops/Random.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Random.cpp
@@ -71,11 +71,74 @@
   return input_arg.clone().detach().uniform_(0.0, 1.0);
 }
 
+Tensor& normal_(
+    Tensor& self, // Vulkan tensor that is overwritten in place and returned
+    const double mean, // mean of the normal distribution
+    const double std, // standard deviation; must be >= 0
+    const c10::optional<at::Generator> /* not implemented */) {
+  TORCH_CHECK(
+      self.is_vulkan(),
+      "Vulkan: In-place operator is only supported on Vulkan tensors.");
+  // A negative standard deviation is rejected; zero is accepted.
+  TORCH_CHECK(
+      std >= 0,
+      "Vulkan: Standard deviation (std) must be non-negative.");
+
+  api::Context* const context = api::context();
+
+  vTensor& v_self = convert(self);
+  // Uniform block for the shader; mean/std are narrowed from double to float.
+  const struct Block final {
+    uvec3 extents;
+    float mean;
+    float std;
+  } block{v_self.extents(), static_cast<float>(mean), static_cast<float>(std)};
+
+  api::UniformParamsBuffer params(context, block);
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      // shader descriptor
+      // (the normal_ kernel)
+      VK_KERNEL(normal_),
+      // pipeline barrier
+      pipeline_barrier,
+      // global work group size
+      v_self.extents(),
+      // local work group size
+      adaptive_work_group_size(v_self.extents()),
+      // fence handle
+      VK_NULL_HANDLE,
+      // shader arguments
+      v_self.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      // params buffer
+      params.buffer());
+
+  return self;
+}
+
+Tensor randn_like(
+    const at::Tensor& input_arg,
+    const c10::optional<c10::ScalarType> /* not implemented */,
+    const c10::optional<c10::Layout> /* not implemented */,
+    const c10::optional<c10::Device> /* not implemented */,
+    const c10::optional<bool> /* not implemented */,
+    const c10::optional<c10::MemoryFormat> /* not implemented */) {
+  // Clones input_arg and fills the copy in place with samples from a normal
+  // distribution with mean 0 and standard deviation 1; the copy is returned.
+  return input_arg.clone().detach().normal_(0.0, 1.0);
+}
+
 #ifdef USE_VULKAN_API
 
 TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
   m.impl(TORCH_SELECTIVE_NAME("aten::uniform_"), TORCH_FN(uniform_));
   m.impl(TORCH_SELECTIVE_NAME("aten::rand_like"), TORCH_FN(rand_like));
+  m.impl(TORCH_SELECTIVE_NAME("aten::normal_"), TORCH_FN(normal_));
+  m.impl(TORCH_SELECTIVE_NAME("aten::randn_like"), TORCH_FN(randn_like));
 }
 
 #endif /* USE_VULKAN_API */
diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp
index fa21380..810209a 100644
--- a/aten/src/ATen/test/vulkan_api_test.cpp
+++ b/aten/src/ATen/test/vulkan_api_test.cpp
@@ -3969,6 +3969,66 @@
   test_uniform(out_vulkan, a_min, a_max);
 }
 
+void test_normal(at::Tensor out_vulkan, const float mean, const float std) {
+  // Sanity-check the sample statistics (this is not a full normality test): the sample mean must be within 5% of std from the given mean, and the sample standard deviation must be within 5% of std from the given std.
+  ASSERT_TRUE(std::abs(at::mean(out_vulkan.cpu()).item<float>() - mean) < std::abs(std) * 0.05);
+  ASSERT_TRUE(std::abs(at::std(out_vulkan.cpu()).item<float>() - std) < std::abs(std) * 0.05);
+}
+
+TEST_F(VulkanAPITest, normal_) {
+  float a_mean = -10.0;
+  float a_std = 2.0;
+  // Fill a small Vulkan tensor in place, then check its sample statistics.
+  auto a_vulkan =
+      at::zeros({3, 4, 5, 6}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
+  a_vulkan.normal_(a_mean, a_std);
+
+  test_normal(a_vulkan, a_mean, a_std);
+}
+
+TEST_F(VulkanAPITest, normal_large) {
+  float a_mean = 1.0;
+  float a_std = 0.001;
+  // Same check as normal_, on a larger tensor with a small standard deviation.
+  auto a_vulkan =
+      at::zeros({30, 40, 50, 60}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
+  a_vulkan.normal_(a_mean, a_std);
+
+  test_normal(a_vulkan, a_mean, a_std);
+}
+
+TEST_F(VulkanAPITest, normal_error) {
+  float a_mean = 1.0;
+  float a_std = -1;
+  // A negative standard deviation is expected to raise c10::Error.
+  auto a_vulkan =
+      at::zeros({30, 40, 50, 60}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
+  EXPECT_THROW(a_vulkan.normal_(a_mean, a_std), ::c10::Error);
+}
+
+TEST_F(VulkanAPITest, randn_like) {
+  float a_mean = 0.0;
+  float a_std = 1.0;
+  // randn_like is expected to produce mean 0 / std 1 samples in a new tensor.
+  auto a_vulkan =
+      at::zeros({8, 7, 6, 5}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
+  const auto out_vulkan = at::randn_like(a_vulkan);
+  // verify that the input is still all zeros (the op is not in-place)
+  ASSERT_TRUE(at::mean(a_vulkan.cpu()).item<float>() == 0.0);
+  test_normal(out_vulkan, a_mean, a_std);
+}
+
+TEST_F(VulkanAPITest, randn_like_large) {
+  float a_mean = 0.0;
+  float a_std = 1.0;
+  // Same statistics check as randn_like, on a larger input tensor.
+  auto a_vulkan =
+      at::zeros({80, 70, 60, 50}, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
+  const auto out_vulkan = at::randn_like(a_vulkan);
+
+  test_normal(out_vulkan, a_mean, a_std);
+}
+
 void test_t(const at::IntArrayRef input_shape) {
   const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
   const auto out_cpu = at::t(in_cpu);