Add Vulkan job dispatch and flush. (#46008)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46008
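
This commit adds a templated Context::dispatch entry point that retrieves/binds the pipeline, binds a descriptor set populated from the trailing arguments, and records a compute dispatch whose group count is the global work group divided by the local work group (rounded up), plus a Context::flush that waits for the device to go idle and purges the command, descriptor, and resource pools.

As a rough illustration of the intended call pattern (a sketch only, not code from this commit: the kernel name `my_kernel`, the work group sizes, and the image arguments are hypothetical placeholders):

```cpp
#include <ATen/native/vulkan/api/Context.h>

namespace at { namespace native { namespace vulkan {

void example_dispatch(
    api::Command::Buffer& command_buffer,
    const api::Resource::Image::Object& v_output,
    const api::Resource::Image::Object& v_input) {
  api::Context* const context = api::context();

  context->dispatch(
      command_buffer,
      {
        // Shader::Layout::Signature: one VkDescriptorType per shader binding.
        VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
        VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
      },
      VK_KERNEL(my_kernel),  // hypothetical shader; resolves to GLSL or SPIR-V
      { 8u, 8u, 1u },        // local work group, baked into the pipeline
      { 64u, 64u, 1u },      // global work group; divided by local size, rounded up
      // Trailing arguments bind to descriptor bindings 0, 1, ... in order.
      v_output,
      v_input);

  // Expensive; intended for debugging only: waits for the device, purges pools.
  context->flush();
}

}}} // namespace at::native::vulkan
```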

Test Plan: Imported from OSS

Reviewed By: IvanKobzarev

Differential Revision: D24291507

Pulled By: AshkanAliabadi

fbshipit-source-id: a3d02e76708a38e49398bb71e31bb2ad676d01af
diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp
index 4461240..cdf96f8 100644
--- a/aten/src/ATen/native/vulkan/api/Command.cpp
+++ b/aten/src/ATen/native/vulkan/api/Command.cpp
@@ -242,17 +242,21 @@
 }
 
 void Command::Buffer::dispatch(
-    const Shader::WorkGroup& work_group) {
+    const Shader::WorkGroup& global_work_group) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       command_buffer_,
       "This command buffer is in an invalid state! "
       "Potential reason: This command buffer is moved from.");
 
+  static const auto div_round_up = [](const uint32_t n, const uint32_t d) {
+    return (n + d - 1u) / d;
+  };
+
   vkCmdDispatch(
       command_buffer_,
-      work_group.x,
-      work_group.y,
-      work_group.z);
+      div_round_up(global_work_group.x, bound_.pipeline.local_work_group.x),
+      div_round_up(global_work_group.y, bound_.pipeline.local_work_group.y),
+      div_round_up(global_work_group.z, bound_.pipeline.local_work_group.z));
 }
 
 void Command::Buffer::submit(
diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h
index aaec2df..69c5238 100644
--- a/aten/src/ATen/native/vulkan/api/Command.h
+++ b/aten/src/ATen/native/vulkan/api/Command.h
@@ -31,7 +31,7 @@
     void bind(const Pipeline::Object& pipeline);
     void bind(const Descriptor::Set& set);
     void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination);
-    void dispatch(const Shader::WorkGroup& work_group);
+    void dispatch(const Shader::WorkGroup& global_work_group);
     void submit(VkQueue queue, Resource::Fence fence = {});
 
    private:
diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h
index cbd53e8..f08a08b 100644
--- a/aten/src/ATen/native/vulkan/api/Common.h
+++ b/aten/src/ATen/native/vulkan/api/Common.h
@@ -2,11 +2,19 @@
 
 #include <ATen/ATen.h>
 
+#ifdef USE_VULKAN_SHADERC_RUNTIME
+#include <ATen/native/vulkan/glsl.h>
+#define VK_KERNEL(name) { name##_glsl, }
+#else
+#include <ATen/native/vulkan/spv.h>
+#define VK_KERNEL(name) { name##_spv, name##_spv_len, }
+#endif /* USE_VULKAN_SHADERC_RUNTIME */
+
 #ifdef USE_VULKAN_WRAPPER
 #include <vulkan_wrapper.h>
 #else
 #include <vulkan/vulkan.h>
-#endif
+#endif /* USE_VULKAN_WRAPPER */
 
 #define VK_CHECK(function)                                  \
   {                                                         \
diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp
index d0fa08d..f28d881 100644
--- a/aten/src/ATen/native/vulkan/api/Context.cpp
+++ b/aten/src/ATen/native/vulkan/api/Context.cpp
@@ -78,25 +78,48 @@
 
 } // namespace
 
-void Context::Deleter::operator()(const VkDevice device) const {
-  // No VK_CHECK.  Don't want an exception thrown in the destructor.
-  vkDeviceWaitIdle(device);
-  vkDestroyDevice(device, nullptr);
-}
-
 Context::Context(const Adapter& adapter)
     : adapter_(adapter),
       device_(
           create_device(
               adapter.handle,
               adapter.compute_queue_family_index),
-          Deleter{}),
+          &VK_DELETER(Device)),
       queue_(acquire_queue(device(), adapter.compute_queue_family_index)),
       command_(gpu()),
       shader_(gpu()),
       pipeline_(gpu()),
       descriptor_(gpu()),
       resource_(gpu()) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      device_,
+      "Invalid Vulkan device!");
+}
+
+Context::~Context() {
+  try {
+    flush();
+  }
+  catch (const std::exception& e) {
+    LOG(WARNING)
+        << "Vulkan: Context destructor raised an exception!  Error: "
+        << e.what();
+  }
+  catch (...) {
+    LOG(WARNING) << "Vulkan: Context destructor raised an unknown exception!";
+  }
+}
+
+void Context::flush() {
+  VK_CHECK(vkDeviceWaitIdle(device()));
+
+  resource().pool.purge();
+  descriptor().pool.purge();
+  command().pool.purge();
+}
+
+bool available() {
+  return context();
 }
 
 Context* context() {
@@ -106,6 +129,40 @@
   return context;
 }
 
+Descriptor::Set dispatch_prologue(
+    Command::Buffer& command_buffer,
+    const Shader::Layout::Signature& shader_layout_signature,
+    const Shader::Descriptor& shader_descriptor,
+    const Shader::WorkGroup& local_work_group) {
+  Descriptor& descriptor = context()->descriptor();
+  Pipeline& pipeline = context()->pipeline();
+  Shader& shader = context()->shader();
+
+  const Shader::Layout::Object shader_layout =
+      shader.layout.cache.retrieve({
+        shader_layout_signature,
+      });
+
+  command_buffer.bind(
+      pipeline.cache.retrieve({
+        pipeline.layout.cache.retrieve({
+          shader_layout.handle,
+        }),
+        shader.cache.retrieve(shader_descriptor),
+        local_work_group,
+      }));
+
+  return descriptor.pool.allocate(shader_layout);
+}
+
+void dispatch_epilogue(
+    Command::Buffer& command_buffer,
+    const Descriptor::Set& descriptor_set,
+    const Shader::WorkGroup& global_work_group) {
+  command_buffer.bind(descriptor_set);
+  command_buffer.dispatch(global_work_group);
+}
+
 } // namespace api
 } // namespace vulkan
 } // namespace native
diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h
index c272312..2be8557 100644
--- a/aten/src/ATen/native/vulkan/api/Context.h
+++ b/aten/src/ATen/native/vulkan/api/Context.h
@@ -29,7 +29,7 @@
   Context(Context&&) = default;
   Context& operator=(const Context&) = delete;
   Context& operator=(Context&&) = default;
-  ~Context() = default;
+  ~Context();
 
   GPU gpu();
   Command& command();
@@ -38,20 +38,31 @@
   Descriptor& descriptor();
   Resource& resource();
 
+  // GPU RPC
+
+  template<typename... Arguments>
+  void dispatch(
+      Command::Buffer& command_buffer,
+      const Shader::Layout::Signature& shader_layout_signature,
+      const Shader::Descriptor& shader_descriptor,
+      const Shader::WorkGroup& local_work_group,
+      const Shader::WorkGroup& global_work_group,
+      Arguments&&... arguments);
+
+  // This function is expensive, and its use has consequences for performance.
+  // Only use it for debugging, or as a short-term workaround on the way to a
+  // more performant solution.
+
+  void flush();
+
  private:
   VkDevice device();
   VkQueue queue();
 
  private:
-  class Deleter final {
-   public:
-    void operator()(VkDevice device) const;
-  };
-
- private:
   // Construction and destruction order matters.  Do not move members around.
   Adapter adapter_;
-  Handle<VkDevice, Deleter> device_;
+  Handle<VkDevice, decltype(&VK_DELETER(Device))> device_;
   VkQueue queue_;
   Command command_;
   Shader shader_;
@@ -60,6 +71,7 @@
   Resource resource_;
 };
 
+bool available();
 Context* context();
 
 //
@@ -105,6 +117,62 @@
   return queue_;
 }
 
+namespace detail {
+
+template<
+    size_t...Indices,
+    typename ...Arguments>
+inline void bind(
+    Descriptor::Set& descriptor_set,
+    const std::index_sequence<Indices...>,
+    Arguments&&...arguments) {
+  C10_UNUSED const int _[]{
+    (descriptor_set.bind(Indices, arguments), 0)...,
+  };
+}
+
+} // namespace detail
+
+template<typename... Arguments>
+inline void Context::dispatch(
+    Command::Buffer& command_buffer,
+    const Shader::Layout::Signature& shader_layout_signature,
+    const Shader::Descriptor& shader_descriptor,
+    const Shader::WorkGroup& local_work_group,
+    const Shader::WorkGroup& global_work_group,
+    Arguments&&... arguments) {
+  // Forward declaration
+  Descriptor::Set dispatch_prologue(
+      Command::Buffer&,
+      const Shader::Layout::Signature&,
+      const Shader::Descriptor&,
+      const Shader::WorkGroup&);
+
+  // Factor out template parameter independent code to minimize code bloat.
+  Descriptor::Set descriptor_set = dispatch_prologue(
+      command_buffer,
+      shader_layout_signature,
+      shader_descriptor,
+      local_work_group);
+
+  detail::bind(
+      descriptor_set,
+      std::index_sequence_for<Arguments...>{},
+      std::forward<Arguments>(arguments)...);
+
+  // Forward declaration
+  void dispatch_epilogue(
+      Command::Buffer&,
+      const Descriptor::Set&,
+      const Shader::WorkGroup&);
+
+  // Factor out template parameter independent code to minimize code bloat.
+  dispatch_epilogue(
+      command_buffer,
+      descriptor_set,
+      global_work_group);
+}
+
 } // namespace api
 } // namespace vulkan
 } // namespace native
diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp
index ab136c0..037f793 100644
--- a/aten/src/ATen/native/vulkan/api/Descriptor.cpp
+++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp
@@ -119,13 +119,15 @@
 Descriptor::Set::Set(
     const VkDevice device,
     const VkDescriptorPool descriptor_pool,
-    const VkDescriptorSetLayout descriptor_set_layout)
+    const Shader::Layout::Object& shader_layout)
   : device_(device),
     descriptor_set_(
         allocate_descriptor_set(
             device_,
             descriptor_pool,
-            descriptor_set_layout)) {
+            shader_layout.handle)),
+    shader_layout_signature_(shader_layout.signature),
+    bindings_{} {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       descriptor_set_,
       "Invalid Vulkan descriptor set!");
@@ -156,7 +158,6 @@
 
 Descriptor::Set& Descriptor::Set::bind(
     const uint32_t binding,
-    const VkDescriptorType type,
     const Resource::Buffer::Object& buffer) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       device_,
@@ -165,7 +166,7 @@
 
   update({
       binding,
-      type,
+      shader_layout_signature_[binding],
       {
         .buffer = {
           buffer.handle,
@@ -180,7 +181,6 @@
 
 Descriptor::Set& Descriptor::Set::bind(
     const uint32_t binding,
-    const VkDescriptorType type,
     const Resource::Image::Object& image) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       device_,
@@ -189,7 +189,7 @@
 
   update({
       binding,
-      type,
+      shader_layout_signature_[binding],
       {
         .image = {
           image.sampler,
@@ -309,7 +309,7 @@
 }
 
 Descriptor::Set Descriptor::Pool::allocate(
-    const VkDescriptorSetLayout descriptor_set_layout)
+    const Shader::Layout::Object& shader_layout)
 {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       device_ && descriptor_pool_,
@@ -317,13 +317,13 @@
       "Potential reason: This descriptor pool is moved from.");
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      descriptor_set_layout,
-      "Invalid Vulkan descriptor set layout!");
+      shader_layout,
+      "Invalid Vulkan shader layout!");
 
   return Set(
       device_,
       descriptor_pool_.get(),
-      descriptor_set_layout);
+      shader_layout);
 }
 
 void Descriptor::Pool::purge() {
diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h
index 72550c5..c005e1f 100644
--- a/aten/src/ATen/native/vulkan/api/Descriptor.h
+++ b/aten/src/ATen/native/vulkan/api/Descriptor.h
@@ -2,6 +2,7 @@
 
 #include <ATen/native/vulkan/api/Common.h>
 #include <ATen/native/vulkan/api/Resource.h>
+#include <ATen/native/vulkan/api/Shader.h>
 
 namespace at {
 namespace native {
@@ -58,7 +59,7 @@
     Set(
         VkDevice device,
         VkDescriptorPool descriptor_pool,
-        VkDescriptorSetLayout descriptor_set_layout);
+        const Shader::Layout::Object& shader_layout);
     Set(const Set&) = delete;
     Set& operator=(const Set&) = delete;
     Set(Set&&);
@@ -67,12 +68,10 @@
 
     Set& bind(
         uint32_t binding,
-        VkDescriptorType type,
         const Resource::Buffer::Object& buffer);
 
     Set& bind(
         uint32_t binding,
-        VkDescriptorType type,
         const Resource::Image::Object& image);
 
     VkDescriptorSet handle() const;
@@ -92,6 +91,7 @@
    private:
     VkDevice device_;
     VkDescriptorSet descriptor_set_;
+    Shader::Layout::Signature shader_layout_signature_;
 
     struct {
       c10::SmallVector<Item, 8u> items;
@@ -112,7 +112,7 @@
     Pool& operator=(Pool&&);
     ~Pool() = default;
 
-    Set allocate(VkDescriptorSetLayout descriptor_set_layout);
+    Set allocate(const Shader::Layout::Object& shader_layout);
     void purge();
 
    private:
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
index d7067c2..93c9543 100644
--- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
@@ -132,7 +132,7 @@
     3u,
     specialization_map_entires,
     sizeof(Shader::WorkGroup),
-    &descriptor.work_group,
+    &descriptor.local_work_group,
   };
 
   const VkComputePipelineCreateInfo compute_pipeline_create_info{
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h
index b168b91..5e3e419 100644
--- a/aten/src/ATen/native/vulkan/api/Pipeline.h
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.h
@@ -99,7 +99,7 @@
   struct Descriptor final {
     VkPipelineLayout pipeline_layout;
     VkShaderModule shader_module;
-    Shader::WorkGroup work_group;
+    Shader::WorkGroup local_work_group;
   };
 
   /*
@@ -132,6 +132,7 @@
   struct Object final {
     VkPipeline handle;
     VkPipelineLayout layout;
+    Shader::WorkGroup local_work_group;
 
     operator bool() const;
   };
@@ -182,7 +183,7 @@
     const Pipeline::Descriptor& _2) {
   return (_1.pipeline_layout == _2.pipeline_layout) &&
          (_1.shader_module == _2.shader_module) &&
-         (_1.work_group == _2.work_group);
+         (_1.local_work_group == _2.local_work_group);
 }
 
 inline size_t Pipeline::Factory::Hasher::operator()(
@@ -190,9 +191,14 @@
   return c10::get_hash(
       descriptor.pipeline_layout,
       descriptor.shader_module,
-      descriptor.work_group.x,
-      descriptor.work_group.y,
-      descriptor.work_group.z);
+      descriptor.local_work_group.x,
+      descriptor.local_work_group.y,
+      descriptor.local_work_group.z);
+}
+
+inline Pipeline::Object::operator bool() const {
+  return (VK_NULL_HANDLE != handle) &&
+         (VK_NULL_HANDLE != layout);
 }
 
 inline Pipeline::Object Pipeline::Cache::retrieve(
@@ -200,6 +206,7 @@
   return {
     cache_.retrieve(descriptor),
     descriptor.pipeline_layout,
+    descriptor.local_work_group,
   };
 }
 
@@ -207,11 +214,6 @@
   cache_.purge();
 }
 
-inline Pipeline::Object::operator bool() const {
-  return (VK_NULL_HANDLE != handle) &&
-         (VK_NULL_HANDLE != layout);
-}
-
 } // namespace api
 } // namespace vulkan
 } // namespace native
diff --git a/aten/src/ATen/native/vulkan/api/Runtime.cpp b/aten/src/ATen/native/vulkan/api/Runtime.cpp
index ce6e3b4..ce5fdf0 100644
--- a/aten/src/ATen/native/vulkan/api/Runtime.cpp
+++ b/aten/src/ATen/native/vulkan/api/Runtime.cpp
@@ -323,10 +323,6 @@
   return runtime.get();
 }
 
-bool available() {
-  return initialize();
-}
-
 Runtime* runtime() {
   Runtime* const runtime = initialize();
   TORCH_CHECK(
diff --git a/aten/src/ATen/native/vulkan/api/Runtime.h b/aten/src/ATen/native/vulkan/api/Runtime.h
index ffc031f..4d06da8 100644
--- a/aten/src/ATen/native/vulkan/api/Runtime.h
+++ b/aten/src/ATen/native/vulkan/api/Runtime.h
@@ -55,7 +55,6 @@
   Handle<VkDebugReportCallbackEXT, Debug> debug_report_callback_;
 };
 
-bool available();
 Runtime* runtime();
 
 } // namespace api
diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp
index e3c336d..2c090d0 100644
--- a/aten/src/ATen/native/vulkan/api/Shader.cpp
+++ b/aten/src/ATen/native/vulkan/api/Shader.cpp
@@ -9,7 +9,6 @@
 namespace vulkan {
 namespace api {
 
-
 Shader::Layout::Factory::Factory(const GPU& gpu)
   : device_(gpu.device) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
@@ -19,12 +18,25 @@
 
 Shader::Layout::Factory::Handle Shader::Layout::Factory::operator()(
     const Descriptor& descriptor) const {
+  c10::SmallVector<VkDescriptorSetLayoutBinding, 8u> bindings;
+
+  uint32_t binding = 0u;
+  for (const VkDescriptorType type : descriptor.signature) {
+    bindings.push_back({
+      binding++,
+      type,
+      1u,
+      VK_SHADER_STAGE_COMPUTE_BIT,
+      nullptr,
+    });
+  }
+
   const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{
     VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
     nullptr,
     0u,
-    static_cast<uint32_t>(descriptor.bindings.size()),
-    descriptor.bindings.data(),
+    static_cast<uint32_t>(bindings.size()),
+    bindings.data(),
   };
 
   VkDescriptorSetLayout descriptor_set_layout{};
@@ -44,6 +56,10 @@
   };
 }
 
+Shader::Layout::Cache::Cache(Factory factory)
+  : cache_(std::move(factory)) {
+}
+
 #ifdef USE_VULKAN_SHADERC_RUNTIME
 
 struct Shader::Factory::Compiler final {
diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h
index 599ed79..8de369c 100644
--- a/aten/src/ATen/native/vulkan/api/Shader.h
+++ b/aten/src/ATen/native/vulkan/api/Shader.h
@@ -39,11 +39,17 @@
 
   struct Layout final {
     /*
+      Signature
+    */
+
+    typedef c10::SmallVector<VkDescriptorType, 8u> Signature;
+
+    /*
       Descriptor
     */
 
     struct Descriptor final {
-      c10::SmallVector<VkDescriptorSetLayoutBinding, 16u> bindings;
+      Signature signature;
     };
 
     /*
@@ -68,12 +74,32 @@
       VkDevice device_;
     };
 
+    struct Object final {
+      VkDescriptorSetLayout handle;
+      Signature signature;
+
+      operator bool() const;
+    };
+
     /*
       Cache
     */
 
-    typedef api::Cache<Factory> Cache;
-    Cache cache;
+    class Cache final {
+     public:
+      explicit Cache(Factory factory);
+      Cache(const Cache&) = delete;
+      Cache& operator=(const Cache&) = delete;
+      Cache(Cache&&) = default;
+      Cache& operator=(Cache&&) = default;
+      ~Cache() = default;
+
+      Object retrieve(const Descriptor& descriptor);
+      void purge();
+
+     private:
+      api::Cache<Factory> cache_;
+    } cache;
 
     explicit Layout(const GPU& gpu)
       : cache(Factory(gpu)) {
@@ -165,27 +191,38 @@
 inline bool operator==(
     const Shader::Layout::Descriptor& _1,
     const Shader::Layout::Descriptor& _2) {
-  return _1.bindings == _2.bindings;
+  return _1.signature == _2.signature;
 }
 
 inline size_t Shader::Layout::Factory::Hasher::operator()(
     const Descriptor& descriptor) const {
   size_t hash = 0u;
 
-  for (const VkDescriptorSetLayoutBinding& binding : descriptor.bindings) {
+  for (const VkDescriptorType type : descriptor.signature) {
     hash = c10::hash_combine(
         hash,
-        c10::get_hash(
-            binding.binding,
-            binding.descriptorType,
-            binding.descriptorCount,
-            binding.stageFlags,
-            binding.pImmutableSamplers));
+        c10::get_hash(type));
   }
 
   return hash;
 }
 
+inline Shader::Layout::Object::operator bool() const {
+  return VK_NULL_HANDLE != handle;
+}
+
+inline Shader::Layout::Object Shader::Layout::Cache::retrieve(
+    const Descriptor& descriptor) {
+  return {
+    cache_.retrieve(descriptor),
+    descriptor.signature,
+  };
+}
+
+inline void Shader::Layout::Cache::purge() {
+  cache_.purge();
+}
+
 inline bool operator==(
     const Shader::WorkGroup& _1,
     const Shader::WorkGroup& _2) {