Vulkan pipeline and pipeline layout cache. (#42395)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42395

Test Plan: Imported from OSS

Reviewed By: IvanKobzarev

Differential Revision: D23252334

Pulled By: AshkanAliabadi

fbshipit-source-id: 6b4e88f9794a7879d47a1cdb671076d50f1944d9
diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp
index 075ea5e..47019a9 100644
--- a/aten/src/ATen/native/vulkan/api/Context.cpp
+++ b/aten/src/ATen/native/vulkan/api/Context.cpp
@@ -258,7 +258,8 @@
       compute_queue_family_index_(query_compute_queue_family_index(physical_device())),
       device_(create_device(physical_device(), compute_queue_family_index_), &VK_DELETER(Device)),
       queue_(acquire_queue(device(), compute_queue_family_index_)),
-      shader_(device()) {
+      shader_(device()),
+      pipeline_(device()) {
 }
 
 Context::Debug::Debug(const VkInstance instance)
diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h
index d9d0f11..1a4c1ce 100644
--- a/aten/src/ATen/native/vulkan/api/Context.h
+++ b/aten/src/ATen/native/vulkan/api/Context.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/native/vulkan/api/Common.h>
+#include <ATen/native/vulkan/api/Pipeline.h>
 #include <ATen/native/vulkan/api/Shader.h>
 
 namespace at {
@@ -44,6 +45,10 @@
     return shader_;
   }
 
+  inline Pipeline& pipeline() {
+    return pipeline_;
+  }
+
  private:
   class Debug final {
    public:
@@ -64,6 +69,7 @@
   Handle<VkDevice, decltype(&VK_DELETER(Device))> device_;
   VkQueue queue_;
   Shader shader_;
+  Pipeline pipeline_;
 };
 
 C10_EXPORT bool available();
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
new file mode 100644
index 0000000..303eea7
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
@@ -0,0 +1,127 @@
+#include <ATen/native/vulkan/api/Pipeline.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace api {
+
+Pipeline::Layout::Factory::Factory(const VkDevice device)
+ : device_(device) {
+}
+
+typename Pipeline::Layout::Factory::Handle Pipeline::Layout::Factory::operator()(
+    const Descriptor& descriptor) const {
+  const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
+    VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
+    nullptr,
+    0u,
+    1u,
+    &descriptor.descriptor_set_layout,
+    0u,
+    nullptr,
+  };
+
+  VkPipelineLayout pipeline_layout{};
+  VK_CHECK(vkCreatePipelineLayout(
+      device_, &pipeline_layout_create_info, nullptr, &pipeline_layout));
+
+  return Handle{
+    pipeline_layout,
+    Deleter(device_),
+  };
+}
+
+namespace {
+
+VkPipelineCache create_pipeline_cache(const VkDevice device) {
+  const VkPipelineCacheCreateInfo pipeline_cache_create_info{
+    VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO,
+    nullptr,
+    0u,
+    0u,
+    nullptr,
+  };
+
+  VkPipelineCache pipeline_cache{};
+  VK_CHECK(vkCreatePipelineCache(
+      device, &pipeline_cache_create_info, nullptr, &pipeline_cache));
+
+  return pipeline_cache;
+}
+
+} // namespace
+
+Pipeline::Factory::Factory(const VkDevice device)
+ : device_(device),
+   pipeline_cache_(create_pipeline_cache(device), VK_DELETER(PipelineCache)(device)) {
+}
+
+typename Pipeline::Factory::Handle Pipeline::Factory::operator()(
+    const Descriptor& descriptor) const {
+  constexpr uint32_t x_offset = 0u;
+  constexpr uint32_t x_size = sizeof(Shader::WorkGroup::x);
+  constexpr uint32_t y_offset = x_offset + x_size;
+  constexpr uint32_t y_size = sizeof(Shader::WorkGroup::y);
+  constexpr uint32_t z_offset = y_offset + y_size;
+  constexpr uint32_t z_size = sizeof(Shader::WorkGroup::z);
+
+  constexpr VkSpecializationMapEntry specialization_map_entires[3]{
+    // X
+    {
+      1u,
+      x_offset,
+      x_size,
+    },
+    // Y
+    {
+      2u,
+      y_offset,
+      y_size,
+    },
+    // Z
+    {
+      3u,
+      z_offset,
+      z_size,
+    },
+  };
+
+  const VkSpecializationInfo specialization_info{
+    3u,
+    specialization_map_entires,
+    sizeof(Shader::WorkGroup),
+    &descriptor.work_group,
+  };
+
+  const VkComputePipelineCreateInfo compute_pipeline_create_info{
+    VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
+    nullptr,
+    0u,
+    VkPipelineShaderStageCreateInfo{
+      VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+      nullptr,
+      0u,
+      VK_SHADER_STAGE_COMPUTE_BIT,
+      descriptor.shader_module,
+      "main",
+      &specialization_info,
+    },
+    descriptor.pipeline_layout,
+    VK_NULL_HANDLE,
+    0u,
+  };
+
+  VkPipeline pipeline{};
+  VK_CHECK(vkCreateComputePipelines(
+      device_, pipeline_cache_.get(), 1u, &compute_pipeline_create_info, nullptr, &pipeline));
+
+  return Handle{
+    pipeline,
+    Deleter(device_),
+  };
+}
+
+} // namespace api
+} // namespace vulkan
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h
new file mode 100644
index 0000000..a5d7232
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.h
@@ -0,0 +1,162 @@
+#pragma once
+
+#include <ATen/native/vulkan/api/Common.h>
+#include <ATen/native/vulkan/api/Cache.h>
+#include <ATen/native/vulkan/api/Shader.h>
+#include <c10/util/hash.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace api {
+
+//
+// This struct defines pipeline, and pipeline layout, caches intended to minimize
+// redundant object reconstructions at the cost of extra memory consumption.
+//
+// A Vulkan pipeline contains the entirety of states, as one coherent monolithic
+// bundle, required to configure the GPU's execution pipeline.  This usage
+// pattern minimizes driver overhead, promotes pipeline state reuse, and is a
+// departure from, and in direct contrast with, OpenGL's individually confiurable
+// state machine.
+//
+// A Vulkan pipeline layout represents a sequence of Vulkan descriptor sets each
+// having a specific layout, and deterimines the interface between all shader
+// stages and shader resources.  For more information on shaders and shader
+// layouts check the description of at::navie::vulkan::api::Shader.
+//
+// This struct defines the facilities required to create, reuse, and destruct
+// these Vulkan objects.
+//
+
+struct C10_EXPORT Pipeline final {
+  //
+  // Layout
+  //
+
+  struct Layout final {
+    /*
+      Descriptor
+    */
+
+    struct Descriptor final {
+      VkDescriptorSetLayout descriptor_set_layout;
+    };
+
+    /*
+      Factory
+    */
+
+    class Factory final {
+     public:
+      explicit Factory(VkDevice device);
+
+      typedef Layout::Descriptor Descriptor;
+      typedef VK_DELETER(PipelineLayout) Deleter;
+      typedef Handle<VkPipelineLayout, Deleter> Handle;
+
+      struct Hasher {
+        size_t operator()(const Descriptor& descriptor) const;
+      };
+
+      Handle operator()(const Descriptor& descriptor) const;
+
+     private:
+      VkDevice device_;
+    };
+
+    /*
+      Cache
+    */
+
+    typedef api::Cache<Factory> Cache;
+    Cache cache;
+
+    explicit Layout(const VkDevice device)
+      : cache(Factory(device)) {
+    }
+  } layout;
+
+  /*
+    Descriptor
+  */
+
+  struct Descriptor final {
+    VkPipelineLayout pipeline_layout;
+    VkShaderModule shader_module;
+    Shader::WorkGroup work_group;
+  };
+
+  /*
+    Factory
+  */
+
+  class Factory final {
+   public:
+    explicit Factory(VkDevice device);
+
+    typedef Pipeline::Descriptor Descriptor;
+    typedef VK_DELETER(Pipeline) Deleter;
+    typedef Handle<VkPipeline, Deleter> Handle;
+
+    struct Hasher {
+      size_t operator()(const Descriptor& descriptor) const;
+    };
+
+    Handle operator()(const Descriptor& descriptor) const;
+
+   private:
+    VkDevice device_;
+    api::Handle<VkPipelineCache, VK_DELETER(PipelineCache)> pipeline_cache_;
+  };
+
+  /*
+    Cache
+  */
+
+  typedef api::Cache<Factory> Cache;
+  Cache cache;
+
+  explicit Pipeline(const VkDevice device)
+    : layout(device),
+      cache(Factory(device)) {
+  }
+};
+
+//
+// Impl
+//
+
+inline bool operator==(
+    const Pipeline::Layout::Descriptor& _1,
+    const Pipeline::Layout::Descriptor& _2) {
+  return (_1.descriptor_set_layout == _2.descriptor_set_layout);
+}
+
+inline size_t Pipeline::Layout::Factory::Hasher::operator()(
+    const Descriptor& descriptor) const {
+  return c10::get_hash(descriptor.descriptor_set_layout);
+}
+
+inline bool operator==(
+    const Pipeline::Descriptor& _1,
+    const Pipeline::Descriptor& _2) {
+  return (_1.pipeline_layout == _2.pipeline_layout) &&
+         (_1.shader_module == _2.shader_module) &&
+         (_1.work_group == _2.work_group);
+}
+
+inline size_t Pipeline::Factory::Hasher::operator()(
+    const Descriptor& descriptor) const {
+  return c10::get_hash(
+      descriptor.pipeline_layout,
+      descriptor.shader_module,
+      descriptor.work_group.x,
+      descriptor.work_group.y,
+      descriptor.work_group.z);
+}
+
+} // namespace api
+} // namespace vulkan
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/vulkan/api/api.h b/aten/src/ATen/native/vulkan/api/api.h
index ac52202..0a91217 100644
--- a/aten/src/ATen/native/vulkan/api/api.h
+++ b/aten/src/ATen/native/vulkan/api/api.h
@@ -2,4 +2,5 @@
 
 #include <ATen/native/vulkan/api/Common.h>
 #include <ATen/native/vulkan/api/Context.h>
+#include <ATen/native/vulkan/api/Pipeline.h>
 #include <ATen/native/vulkan/api/Shader.h>