Vulkan pipeline and pipeline layout cache. (#42395)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42395
Test Plan: Imported from OSS
Reviewed By: IvanKobzarev
Differential Revision: D23252334
Pulled By: AshkanAliabadi
fbshipit-source-id: 6b4e88f9794a7879d47a1cdb671076d50f1944d9
diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp
index 075ea5e..47019a9 100644
--- a/aten/src/ATen/native/vulkan/api/Context.cpp
+++ b/aten/src/ATen/native/vulkan/api/Context.cpp
@@ -258,7 +258,8 @@
compute_queue_family_index_(query_compute_queue_family_index(physical_device())),
device_(create_device(physical_device(), compute_queue_family_index_), &VK_DELETER(Device)),
queue_(acquire_queue(device(), compute_queue_family_index_)),
- shader_(device()) {
+ shader_(device()),
+ pipeline_(device()) {
}
Context::Debug::Debug(const VkInstance instance)
diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h
index d9d0f11..1a4c1ce 100644
--- a/aten/src/ATen/native/vulkan/api/Context.h
+++ b/aten/src/ATen/native/vulkan/api/Context.h
@@ -1,6 +1,7 @@
#pragma once
#include <ATen/native/vulkan/api/Common.h>
+#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/Shader.h>
namespace at {
@@ -44,6 +45,10 @@
return shader_;
}
+ inline Pipeline& pipeline() {
+ return pipeline_;
+ }
+
private:
class Debug final {
public:
@@ -64,6 +69,7 @@
Handle<VkDevice, decltype(&VK_DELETER(Device))> device_;
VkQueue queue_;
Shader shader_;
+ Pipeline pipeline_;
};
C10_EXPORT bool available();
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
new file mode 100644
index 0000000..303eea7
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
@@ -0,0 +1,127 @@
+#include <ATen/native/vulkan/api/Pipeline.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace api {
+
+Pipeline::Layout::Factory::Factory(const VkDevice device)
+ : device_(device) {
+}
+
+typename Pipeline::Layout::Factory::Handle Pipeline::Layout::Factory::operator()(
+ const Descriptor& descriptor) const {
+ const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
+ VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
+ nullptr,
+ 0u,
+ 1u,
+ &descriptor.descriptor_set_layout,
+ 0u,
+ nullptr,
+ };
+
+ VkPipelineLayout pipeline_layout{};
+ VK_CHECK(vkCreatePipelineLayout(
+ device_, &pipeline_layout_create_info, nullptr, &pipeline_layout));
+
+ return Handle{
+ pipeline_layout,
+ Deleter(device_),
+ };
+}
+
+namespace {
+
+VkPipelineCache create_pipeline_cache(const VkDevice device) {
+ const VkPipelineCacheCreateInfo pipeline_cache_create_info{
+ VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO,
+ nullptr,
+ 0u,
+ 0u,
+ nullptr,
+ };
+
+ VkPipelineCache pipeline_cache{};
+ VK_CHECK(vkCreatePipelineCache(
+ device, &pipeline_cache_create_info, nullptr, &pipeline_cache));
+
+ return pipeline_cache;
+}
+
+} // namespace
+
+Pipeline::Factory::Factory(const VkDevice device)
+ : device_(device),
+ pipeline_cache_(create_pipeline_cache(device), VK_DELETER(PipelineCache)(device)) {
+}
+
+typename Pipeline::Factory::Handle Pipeline::Factory::operator()(
+ const Descriptor& descriptor) const {
+ constexpr uint32_t x_offset = 0u;
+ constexpr uint32_t x_size = sizeof(Shader::WorkGroup::x);
+ constexpr uint32_t y_offset = x_offset + x_size;
+ constexpr uint32_t y_size = sizeof(Shader::WorkGroup::y);
+ constexpr uint32_t z_offset = y_offset + y_size;
+ constexpr uint32_t z_size = sizeof(Shader::WorkGroup::z);
+
+ constexpr VkSpecializationMapEntry specialization_map_entires[3]{
+ // X
+ {
+ 1u,
+ x_offset,
+ x_size,
+ },
+ // Y
+ {
+ 2u,
+ y_offset,
+ y_size,
+ },
+ // Z
+ {
+ 3u,
+ z_offset,
+ z_size,
+ },
+ };
+
+ const VkSpecializationInfo specialization_info{
+ 3u,
+ specialization_map_entires,
+ sizeof(Shader::WorkGroup),
+ &descriptor.work_group,
+ };
+
+ const VkComputePipelineCreateInfo compute_pipeline_create_info{
+ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
+ nullptr,
+ 0u,
+ VkPipelineShaderStageCreateInfo{
+ VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+ nullptr,
+ 0u,
+ VK_SHADER_STAGE_COMPUTE_BIT,
+ descriptor.shader_module,
+ "main",
+ &specialization_info,
+ },
+ descriptor.pipeline_layout,
+ VK_NULL_HANDLE,
+ 0u,
+ };
+
+ VkPipeline pipeline{};
+ VK_CHECK(vkCreateComputePipelines(
+ device_, pipeline_cache_.get(), 1u, &compute_pipeline_create_info, nullptr, &pipeline));
+
+ return Handle{
+ pipeline,
+ Deleter(device_),
+ };
+}
+
+} // namespace api
+} // namespace vulkan
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h
new file mode 100644
index 0000000..a5d7232
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.h
@@ -0,0 +1,162 @@
+#pragma once
+
+#include <ATen/native/vulkan/api/Common.h>
+#include <ATen/native/vulkan/api/Cache.h>
+#include <ATen/native/vulkan/api/Shader.h>
+#include <c10/util/hash.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace api {
+
+//
+// This struct defines pipeline, and pipeline layout, caches intended to minimize
+// redundant object reconstructions at the cost of extra memory consumption.
+//
+// A Vulkan pipeline contains the entirety of states, as one coherent monolithic
+// bundle, required to configure the GPU's execution pipeline. This usage
+// pattern minimizes driver overhead, promotes pipeline state reuse, and is a
+// departure from, and in direct contrast with, OpenGL's individually confiurable
+// state machine.
+//
+// A Vulkan pipeline layout represents a sequence of Vulkan descriptor sets each
+// having a specific layout, and deterimines the interface between all shader
+// stages and shader resources. For more information on shaders and shader
+// layouts check the description of at::navie::vulkan::api::Shader.
+//
+// This struct defines the facilities required to create, reuse, and destruct
+// these Vulkan objects.
+//
+
+struct C10_EXPORT Pipeline final {
+ //
+ // Layout
+ //
+
+ struct Layout final {
+ /*
+ Descriptor
+ */
+
+ struct Descriptor final {
+ VkDescriptorSetLayout descriptor_set_layout;
+ };
+
+ /*
+ Factory
+ */
+
+ class Factory final {
+ public:
+ explicit Factory(VkDevice device);
+
+ typedef Layout::Descriptor Descriptor;
+ typedef VK_DELETER(PipelineLayout) Deleter;
+ typedef Handle<VkPipelineLayout, Deleter> Handle;
+
+ struct Hasher {
+ size_t operator()(const Descriptor& descriptor) const;
+ };
+
+ Handle operator()(const Descriptor& descriptor) const;
+
+ private:
+ VkDevice device_;
+ };
+
+ /*
+ Cache
+ */
+
+ typedef api::Cache<Factory> Cache;
+ Cache cache;
+
+ explicit Layout(const VkDevice device)
+ : cache(Factory(device)) {
+ }
+ } layout;
+
+ /*
+ Descriptor
+ */
+
+ struct Descriptor final {
+ VkPipelineLayout pipeline_layout;
+ VkShaderModule shader_module;
+ Shader::WorkGroup work_group;
+ };
+
+ /*
+ Factory
+ */
+
+ class Factory final {
+ public:
+ explicit Factory(VkDevice device);
+
+ typedef Pipeline::Descriptor Descriptor;
+ typedef VK_DELETER(Pipeline) Deleter;
+ typedef Handle<VkPipeline, Deleter> Handle;
+
+ struct Hasher {
+ size_t operator()(const Descriptor& descriptor) const;
+ };
+
+ Handle operator()(const Descriptor& descriptor) const;
+
+ private:
+ VkDevice device_;
+ api::Handle<VkPipelineCache, VK_DELETER(PipelineCache)> pipeline_cache_;
+ };
+
+ /*
+ Cache
+ */
+
+ typedef api::Cache<Factory> Cache;
+ Cache cache;
+
+ explicit Pipeline(const VkDevice device)
+ : layout(device),
+ cache(Factory(device)) {
+ }
+};
+
+//
+// Impl
+//
+
+inline bool operator==(
+ const Pipeline::Layout::Descriptor& _1,
+ const Pipeline::Layout::Descriptor& _2) {
+ return (_1.descriptor_set_layout == _2.descriptor_set_layout);
+}
+
+inline size_t Pipeline::Layout::Factory::Hasher::operator()(
+ const Descriptor& descriptor) const {
+ return c10::get_hash(descriptor.descriptor_set_layout);
+}
+
+inline bool operator==(
+ const Pipeline::Descriptor& _1,
+ const Pipeline::Descriptor& _2) {
+ return (_1.pipeline_layout == _2.pipeline_layout) &&
+ (_1.shader_module == _2.shader_module) &&
+ (_1.work_group == _2.work_group);
+}
+
+inline size_t Pipeline::Factory::Hasher::operator()(
+ const Descriptor& descriptor) const {
+ return c10::get_hash(
+ descriptor.pipeline_layout,
+ descriptor.shader_module,
+ descriptor.work_group.x,
+ descriptor.work_group.y,
+ descriptor.work_group.z);
+}
+
+} // namespace api
+} // namespace vulkan
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/vulkan/api/api.h b/aten/src/ATen/native/vulkan/api/api.h
index ac52202..0a91217 100644
--- a/aten/src/ATen/native/vulkan/api/api.h
+++ b/aten/src/ATen/native/vulkan/api/api.h
@@ -2,4 +2,5 @@
#include <ATen/native/vulkan/api/Common.h>
#include <ATen/native/vulkan/api/Context.h>
+#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/Shader.h>