[vulkan] Refactor Descriptor::Pool (#80727)

Part of a refactor of the Vulkan codebase.

Differential Revision: [D37125677](https://our.internmc.facebook.com/intern/diff/D37125677/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80727
Approved by: https://github.com/kimishpatel
diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp
index 60b2e58..aec7b93 100644
--- a/aten/src/ATen/native/vulkan/api/Context.cpp
+++ b/aten/src/ATen/native/vulkan/api/Context.cpp
@@ -20,7 +20,7 @@
       queue_(adapter_p_->request_queue()),
       // Resource pools
       command_pool_(device_, queue_.family_index, config_.cmdPoolConfig),
-      descriptor_(gpu()),
+      descriptor_pool_(device_, config_.descriptorPoolConfig),
       fences_(device_),
       querypool_(
         device_,
@@ -63,7 +63,8 @@
   command_buffer.bind_pipeline(
       pipeline, pipeline_layout, local_workgroup_size);
 
-  return descriptor().pool.allocate(shader_layout, shader_layout_signature);
+  return descriptor_pool().get_descriptor_set(
+      shader_layout, shader_layout_signature);
 }
 
 void Context::submit_compute_epilogue(
@@ -127,7 +128,7 @@
   VK_CHECK(vkQueueWaitIdle(queue()));
 
   command_pool_.flush();
-  descriptor().pool.purge();
+  descriptor_pool_.flush();
 
   std::lock_guard<std::mutex> bufferlist_lock(buffer_clearlist_mutex_);
   std::lock_guard<std::mutex> imagelist_lock(image_clearlist_mutex_);
@@ -142,12 +143,24 @@
 Context* context() {
   static const std::unique_ptr<Context> context([]() -> Context* {
     try {
+      const CommandPoolConfig cmd_config{
+        32u,  // cmdPoolInitialSize
+        8u,  // cmdPoolBatchSize
+      };
+
+      const DescriptorPoolConfig descriptor_pool_config{
+        1024u,  // descriptorPoolMaxSets
+        1024u,  // descriptorUniformBufferCount
+        1024u,  // descriptorStorageBufferCount
+        1024u,  // descriptorCombinedSamplerCount
+        1024u,  // descriptorStorageImageCount
+        32u,  // descriptorPileSizes
+      };
+
       const ContextConfig config{
         16u,  // cmdSubmitFrequency
-        {  // cmdPoolConfig
-          32u,  // cmdPoolInitialSize
-          8u,  // cmdPoolBatchSize
-        },
+        cmd_config,  // cmdPoolConfig
+        descriptor_pool_config,  // descriptorPoolConfig
       };
       return new Context(
           runtime()->instance(), runtime()->default_adapter_i(), config);
diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h
index e402419..125d4a2 100644
--- a/aten/src/ATen/native/vulkan/api/Context.h
+++ b/aten/src/ATen/native/vulkan/api/Context.h
@@ -20,6 +20,7 @@
 struct ContextConfig final {
   uint32_t cmdSubmitFrequency;
   CommandPoolConfig cmdPoolConfig;
+  DescriptorPoolConfig descriptorPoolConfig;
 };
 
 //
@@ -54,7 +55,7 @@
   Adapter::Queue queue_;
   // Resource Pools
   CommandPool command_pool_;
-  Descriptor descriptor_;
+  DescriptorPool descriptor_pool_;
   FencePool fences_;
   QueryPool querypool_;
   // Command buffers submission
@@ -115,8 +116,8 @@
 
   // Resource Pools
 
-  inline Descriptor& descriptor() {
-    return descriptor_;
+  inline DescriptorPool& descriptor_pool() {
+    return descriptor_pool_;
   }
 
   inline FencePool& fences() {
diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp
index 8d65e40..f8c44bf 100644
--- a/aten/src/ATen/native/vulkan/api/Descriptor.cpp
+++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp
@@ -137,214 +137,140 @@
   }
 }
 
-namespace {
+//
+// DescriptorSetPile
+//
 
-VkDescriptorPool create_descriptor_pool(const VkDevice device) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      device,
-      "Invalid Vulkan device!");
-
-  const struct {
-    uint32_t capacity;
-    c10::SmallVector<VkDescriptorPoolSize, 4u> sizes;
-  } descriptor {
-    1024u,
-    {
-      /*
-        Buffers
-      */
-
-      {
-        VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
-        1024u,
-      },
-      {
-        VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
-        1024u,
-      },
-
-      /*
-        Images
-      */
-
-      {
-        VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
-        1024u,
-      },
-      {
-        VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
-        1024u,
-      },
-    },
-  };
-
-  const VkDescriptorPoolCreateInfo descriptor_pool_create_info{
-    VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO,
-    nullptr,
-    0u,
-    descriptor.capacity,
-    static_cast<uint32_t>(descriptor.sizes.size()),
-    descriptor.sizes.data(),
-  };
-
-  VkDescriptorPool descriptor_pool{};
-  VK_CHECK(vkCreateDescriptorPool(
-      device,
-      &descriptor_pool_create_info,
-      nullptr,
-      &descriptor_pool));
-
-  TORCH_CHECK(
-      descriptor_pool,
-      "Invalid Vulkan descriptor pool!");
-
-  return descriptor_pool;
+DescriptorSetPile::DescriptorSetPile(
+    const uint32_t pile_size,
+    const VkDescriptorSetLayout descriptor_set_layout,
+    const VkDevice device,
+    const VkDescriptorPool descriptor_pool)
+  : pile_size_{pile_size},
+    set_layout_{descriptor_set_layout},
+    device_{device},
+    pool_{descriptor_pool},
+    descriptors_{},
+    in_use_(0u) {
+  descriptors_.resize(pile_size_);
+  allocate_new_batch();
 }
 
-void allocate_descriptor_sets(
-    const VkDevice device,
-    const VkDescriptorPool descriptor_pool,
-    const VkDescriptorSetLayout descriptor_set_layout,
-    VkDescriptorSet* const descriptor_sets,
-    const uint32_t count) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      device,
-      "Invalid Vulkan device!");
+VkDescriptorSet DescriptorSetPile::get_descriptor_set() {
+  // No-ops if there are descriptor sets available
+  allocate_new_batch();
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      descriptor_pool,
-      "Invalid Vulkan descriptor pool!");
+  const VkDescriptorSet handle = descriptors_[in_use_];
+  descriptors_[in_use_] = VK_NULL_HANDLE;
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      descriptor_set_layout,
-      "Invalid Vulkan descriptor set layout!");
+  in_use_++;
+  return handle;
+}
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      descriptor_sets && (count > 0u),
-      "Invalid usage!");
+void DescriptorSetPile::allocate_new_batch() {
+  // No-ops if there are still descriptor sets available
+  if (in_use_ < descriptors_.size() &&
+      descriptors_[in_use_] != VK_NULL_HANDLE) {
+    return;
+  }
 
-  std::vector<VkDescriptorSetLayout> descriptor_set_layouts(count);
-  fill(
-    descriptor_set_layouts.begin(),
-    descriptor_set_layouts.end(),
-    descriptor_set_layout
-  );
+  std::vector<VkDescriptorSetLayout> layouts(descriptors_.size());
+  fill(layouts.begin(), layouts.end(), set_layout_);
 
-  const VkDescriptorSetAllocateInfo descriptor_set_allocate_info{
-    VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO,
-    nullptr,
-    descriptor_pool,
-    utils::safe_downcast<uint32_t>(descriptor_set_layouts.size()),
-    descriptor_set_layouts.data(),
+  const VkDescriptorSetAllocateInfo allocate_info{
+    VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO,  // sType
+    nullptr,  // pNext
+    pool_,  // descriptorPool
+    utils::safe_downcast<uint32_t>(layouts.size()),  // descriptorSetCount
+    layouts.data(),  // pSetLayouts
   };
 
   VK_CHECK(vkAllocateDescriptorSets(
-      device,
-      &descriptor_set_allocate_info,
-      descriptor_sets));
-}
-
-} // namespace
-
-Descriptor::Pool::Pool(const GPU& gpu)
-  : device_(gpu.device),
-    descriptor_pool_(
-        create_descriptor_pool(gpu.device),
-        VK_DELETER(DescriptorPool)(device_)) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
       device_,
-      "Invalid Vulkan device!");
+      &allocate_info,
+      descriptors_.data()));
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      descriptor_pool_,
-      "Invalid Vulkan descriptor pool!");
+  in_use_ = 0u;
 }
 
-Descriptor::Pool::Pool(Pool&& pool)
-  : device_(std::move(pool.device_)),
-    descriptor_pool_(std::move(pool.descriptor_pool_)),
-    set_(std::move(pool.set_)) {
-  pool.invalidate();
-}
+//
+// DescriptorPool
+//
 
-Descriptor::Pool& Descriptor::Pool::operator=(Pool&& pool) {
-  if (&pool != this) {
-    device_ = std::move(pool.device_);
-    descriptor_pool_ = std::move(pool.descriptor_pool_);
-    set_ = std::move(pool.set_);
-
-    pool.invalidate();
+DescriptorPool::DescriptorPool(
+    const VkDevice device,
+    const DescriptorPoolConfig& config)
+  : device_(device),
+    pool_(VK_NULL_HANDLE),
+    config_(config),
+    mutex_{},
+    piles_{} {
+  c10::SmallVector<VkDescriptorPoolSize, 4u> type_sizes {
+    {
+      VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+      config_.descriptorUniformBufferCount,
+    },
+    {
+      VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+      config_.descriptorStorageBufferCount,
+    },
+    {
+      VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+      config_.descriptorCombinedSamplerCount,
+    },
+    {
+      VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+      config_.descriptorStorageImageCount,
+    },
   };
 
-  return *this;
-}
+  const VkDescriptorPoolCreateInfo create_info{
+    VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO,  // sType
+    nullptr,  // pNext
+    0u,  // flags
+    config_.descriptorPoolMaxSets,  // maxSets
+    static_cast<uint32_t>(type_sizes.size()),  // poolSizeCount
+    type_sizes.data(),  // pPoolSizes
+  };
 
-Descriptor::Pool::~Pool() {
-  try {
-    if (device_ && descriptor_pool_) {
-      purge();
-    }
-  }
-  catch (const std::exception& e) {
-    TORCH_WARN(
-        "Vulkan: Descriptor pool destructor raised an exception! Error: ",
-        e.what());
-  }
-  catch (...) {
-    TORCH_WARN(
-        "Vulkan: Descriptor pool destructor raised an exception! "
-        "Error: Unknown");
-  }
-}
-
-DescriptorSet Descriptor::Pool::allocate(
-    const VkDescriptorSetLayout handle,
-    const ShaderLayout::Signature& signature) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      device_ && descriptor_pool_,
-      "This descriptor pool is in an invalid state! "
-      "Potential reason: This descriptor pool is moved from.");
-
-  auto iterator = set_.layouts.find(handle);
-  if (set_.layouts.cend() == iterator) {
-    iterator = set_.layouts.insert({handle, {}}).first;
-    iterator->second.pool.reserve(Configuration::kReserve);
-  }
-
-  auto& layout = iterator->second;
-
-  if (layout.pool.size() == layout.in_use) {
-    layout.pool.resize(
-        layout.pool.size() +
-        Configuration::kQuantum);
-
-    allocate_descriptor_sets(
-        device_,
-        descriptor_pool_.get(),
-        handle,
-        layout.pool.data() + layout.in_use,
-        Configuration::kQuantum);
-  }
-
-  return DescriptorSet(
+  VK_CHECK(vkCreateDescriptorPool(
       device_,
-      layout.pool[layout.in_use++],
-      signature);
+      &create_info,
+      nullptr,
+      &pool_));
 }
 
-void Descriptor::Pool::purge() {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-      device_ && descriptor_pool_,
-      "This descriptor pool is in an invalid state! "
-      "Potential reason: This descriptor pool is moved from.");
-
-  VK_CHECK(vkResetDescriptorPool(device_, descriptor_pool_.get(), 0u));
-  set_.layouts.clear();
+DescriptorPool::~DescriptorPool() {
+  if (VK_NULL_HANDLE == pool_) {
+    return;
+  }
+  vkDestroyDescriptorPool(device_, pool_, nullptr);
 }
 
-void Descriptor::Pool::invalidate() {
-  device_ = VK_NULL_HANDLE;
-  descriptor_pool_.reset();
+DescriptorSet DescriptorPool::get_descriptor_set(
+    const VkDescriptorSetLayout set_layout,
+    const ShaderLayout::Signature& signature) {
+  auto it = piles_.find(set_layout);
+  if (piles_.cend() == it) {
+    it = piles_.insert(
+        {
+          set_layout,
+          DescriptorSetPile(
+              config_.descriptorPileSizes,
+              set_layout,
+              device_,
+              pool_),
+        }).first;
+  }
+
+  VkDescriptorSet handle = it->second.get_descriptor_set();
+
+  return DescriptorSet(device_, handle, signature);
+}
+
+void DescriptorPool::flush() {
+  VK_CHECK(vkResetDescriptorPool(device_, pool_, 0u));
+  piles_.clear();
 }
 
 } // namespace api
diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h
index e08513d..9a17b7b 100644
--- a/aten/src/ATen/native/vulkan/api/Descriptor.h
+++ b/aten/src/ATen/native/vulkan/api/Descriptor.h
@@ -53,90 +53,76 @@
   void add_binding(const ResourceBinding& resource);
 };
 
-//
-// This struct defines caches of descriptor pools, and descriptor sets allocated
-// from those pools, intended to minimize redundant object reconstructions or
-// accelerate unavoidable memory allocations, both at the cost of extra memory
-// consumption.
-//
-// A descriptor set is logically an array of descriptors, each of which
-// references a resource (i.e. buffers and images), in turn telling the core
-// executing the shader, where in GPU, or GPU-accessible system, memory the said
-// resource resides.
-//
-// To accelerate creation of the descriptor sets, modern graphics APIs allocate
-// them from a pool, more elaborately referred to as descriptor pools, which do
-// need to be purged frequently _after_ none of the descriptors the pools contain
-// is in use by the GPU.  Care must be taken that descriptors are not freed while
-// they are in use by the pipeline, which considering the asynchronous nature of
-// CPU-GPU interactions, can be anytime after the command is issued until it is
-// fully executed by the GPU.
-//
-// As you can imagine, it is possible to have multiple descriptor pools, each of
-// which is configured to house different types of descriptor sets with different
-// allocation strategies. These descriptor pools themselves are fairly stable
-// objects in that they theymself should not be created and destroyed frequently.
-// That is the reason why we store them in a cache, which according to our usage
-// of the term 'cache' in this implementatoin, is reserved for objects that are
-// created infrequently and stabilize to a manageable number quickly over the
-// lifetime of the program.
-//
-// Descriptor sets though, on the other hand, are allocated from pools which
-// indeed does mean that the pools must be purged on a regular basis or else
-// they will run out of free items.  Again, this is in line with our usage of
-// the term 'pool' in this implementation which we use to refer to a container
-// of objects that is allocated out of and is required to be frequently purged.
-//
-// It is important to point out that for performance reasons, we intentionally
-// do not free the descriptor sets individually, and instead opt to purge the
-// pool in its totality, even though Vulkan supports the former usage pattern
-// as well.  This behavior is by design.
-//
+class DescriptorSetPile final {
+ public:
+  DescriptorSetPile(
+      const uint32_t,
+      const VkDescriptorSetLayout,
+      const VkDevice,
+      const VkDescriptorPool);
 
-struct Descriptor final {
-  //
-  // Pool
-  //
+  DescriptorSetPile(const DescriptorSetPile&) = delete;
+  DescriptorSetPile& operator=(const DescriptorSetPile&) = delete;
 
-  class Pool final {
-   public:
-    explicit Pool(const GPU& gpu);
-    Pool(const Pool&) = delete;
-    Pool& operator=(const Pool&) = delete;
-    Pool(Pool&&);
-    Pool& operator=(Pool&&);
-    ~Pool();
+  DescriptorSetPile(DescriptorSetPile&&) = default;
+  DescriptorSetPile& operator=(DescriptorSetPile&&) = default;
 
-    DescriptorSet allocate(
-        const VkDescriptorSetLayout handle,
-        const ShaderLayout::Signature& signature);
-    void purge();
+  ~DescriptorSetPile() = default;
 
-   private:
-    void invalidate();
+ private:
+  uint32_t pile_size_;
+  VkDescriptorSetLayout set_layout_;
+  VkDevice device_;
+  VkDescriptorPool pool_;
+  std::vector<VkDescriptorSet> descriptors_;
+  size_t in_use_;
 
-   private:
-    struct Configuration final {
-      static constexpr uint32_t kQuantum = 16u;
-      static constexpr uint32_t kReserve = 64u;
-    };
+ public:
+  VkDescriptorSet get_descriptor_set();
 
-    VkDevice device_;
-    Handle<VkDescriptorPool, VK_DELETER(DescriptorPool)> descriptor_pool_;
+ private:
+  void allocate_new_batch();
+};
 
-    struct {
-      struct Layout final {
-        std::vector<VkDescriptorSet> pool;
-        size_t in_use;
-      };
+struct DescriptorPoolConfig final {
+  // Overall Pool capacity
+  uint32_t descriptorPoolMaxSets;
+  // DescriptorCounts by type
+  uint32_t descriptorUniformBufferCount;
+  uint32_t descriptorStorageBufferCount;
+  uint32_t descriptorCombinedSamplerCount;
+  uint32_t descriptorStorageImageCount;
+  // Pile size for pre-allocating descriptor sets
+  uint32_t descriptorPileSizes;
+};
 
-      ska::flat_hash_map<VkDescriptorSetLayout, Layout> layouts;
-    } set_;
-  } pool /* [thread_count] */;
+class DescriptorPool final {
+ public:
+  explicit DescriptorPool(
+      const VkDevice, const DescriptorPoolConfig&);
 
-  explicit Descriptor(const GPU& gpu)
-    : pool(gpu) {
-  }
+  DescriptorPool(const DescriptorPool&) = delete;
+  DescriptorPool& operator=(const DescriptorPool&) = delete;
+
+  DescriptorPool(DescriptorPool&&) = delete;
+  DescriptorPool& operator=(DescriptorPool&&) = delete;
+
+  ~DescriptorPool();
+
+ private:
+  VkDevice device_;
+  VkDescriptorPool pool_;
+  DescriptorPoolConfig config_;
+  // New Descriptors
+  std::mutex mutex_;
+  ska::flat_hash_map<VkDescriptorSetLayout, DescriptorSetPile> piles_;
+
+ public:
+  DescriptorSet get_descriptor_set(
+      const VkDescriptorSetLayout handle,
+      const ShaderLayout::Signature& signature);
+
+  void flush();
 };
 
 } // namespace api