Add Vulkan job dispatch and flush. (#46008)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46008
Test Plan: Imported from OSS
Reviewed By: IvanKobzarev
Differential Revision: D24291507
Pulled By: AshkanAliabadi
fbshipit-source-id: a3d02e76708a38e49398bb71e31bb2ad676d01af
diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp
index 4461240..cdf96f8 100644
--- a/aten/src/ATen/native/vulkan/api/Command.cpp
+++ b/aten/src/ATen/native/vulkan/api/Command.cpp
@@ -242,17 +242,21 @@
}
void Command::Buffer::dispatch(
- const Shader::WorkGroup& work_group) {
+ const Shader::WorkGroup& global_work_group) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
command_buffer_,
"This command buffer is in an invalid state! "
"Potential reason: This command buffer is moved from.");
+ static const auto div_round_up = [](const uint32_t n, const uint32_t d) {
+ return (n + d - 1u) / d;
+ };
+
vkCmdDispatch(
command_buffer_,
- work_group.x,
- work_group.y,
- work_group.z);
+ div_round_up(global_work_group.x, bound_.pipeline.local_work_group.x),
+ div_round_up(global_work_group.y, bound_.pipeline.local_work_group.y),
+ div_round_up(global_work_group.z, bound_.pipeline.local_work_group.z));
}
void Command::Buffer::submit(
diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h
index aaec2df..69c5238 100644
--- a/aten/src/ATen/native/vulkan/api/Command.h
+++ b/aten/src/ATen/native/vulkan/api/Command.h
@@ -31,7 +31,7 @@
void bind(const Pipeline::Object& pipeline);
void bind(const Descriptor::Set& set);
void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination);
- void dispatch(const Shader::WorkGroup& work_group);
+ void dispatch(const Shader::WorkGroup& global_work_group);
void submit(VkQueue queue, Resource::Fence fence = {});
private:
diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h
index cbd53e8..f08a08b 100644
--- a/aten/src/ATen/native/vulkan/api/Common.h
+++ b/aten/src/ATen/native/vulkan/api/Common.h
@@ -2,11 +2,19 @@
#include <ATen/ATen.h>
+#ifdef USE_VULKAN_SHADERC_RUNTIME
+#include <ATen/native/vulkan/glsl.h>
+#define VK_KERNEL(name) { name##_glsl, }
+#else
+#include <ATen/native/vulkan/spv.h>
+#define VK_KERNEL(name) { name##_spv, name##_spv_len, }
+#endif /* USE_VULKAN_SHADERC_RUNTIME */
+
#ifdef USE_VULKAN_WRAPPER
#include <vulkan_wrapper.h>
#else
#include <vulkan/vulkan.h>
-#endif
+#endif /* USE_VULKAN_WRAPPER */
#define VK_CHECK(function) \
{ \
diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp
index d0fa08d..f28d881 100644
--- a/aten/src/ATen/native/vulkan/api/Context.cpp
+++ b/aten/src/ATen/native/vulkan/api/Context.cpp
@@ -78,25 +78,48 @@
} // namespace
-void Context::Deleter::operator()(const VkDevice device) const {
- // No VK_CHECK. Don't want an exception thrown in the destructor.
- vkDeviceWaitIdle(device);
- vkDestroyDevice(device, nullptr);
-}
-
Context::Context(const Adapter& adapter)
: adapter_(adapter),
device_(
create_device(
adapter.handle,
adapter.compute_queue_family_index),
- Deleter{}),
+ &VK_DELETER(Device)),
queue_(acquire_queue(device(), adapter.compute_queue_family_index)),
command_(gpu()),
shader_(gpu()),
pipeline_(gpu()),
descriptor_(gpu()),
resource_(gpu()) {
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ device_,
+ "Invalid Vulkan device!");
+}
+
+Context::~Context() {
+ try {
+ flush();
+ }
+ catch (const std::exception& e) {
+ LOG(WARNING)
+ << "Vulkan: Context destructor raised an exception! Error: "
+ << e.what();
+ }
+ catch (...) {
+ LOG(WARNING) << "Vulkan: Context destructor raised an unknown exception!";
+ }
+}
+
+void Context::flush() {
+ VK_CHECK(vkDeviceWaitIdle(device()));
+
+ resource().pool.purge();
+ descriptor().pool.purge();
+ command().pool.purge();
+}
+
+bool available() {
+ return context();
}
Context* context() {
@@ -106,6 +129,40 @@
return context;
}
+Descriptor::Set dispatch_prologue(
+ Command::Buffer& command_buffer,
+ const Shader::Layout::Signature& shader_layout_signature,
+ const Shader::Descriptor& shader_descriptor,
+ const Shader::WorkGroup& local_work_group) {
+ Descriptor& descriptor = context()->descriptor();
+ Pipeline& pipeline = context()->pipeline();
+ Shader& shader = context()->shader();
+
+ const Shader::Layout::Object shader_layout =
+ shader.layout.cache.retrieve({
+ shader_layout_signature,
+ });
+
+ command_buffer.bind(
+ pipeline.cache.retrieve({
+ pipeline.layout.cache.retrieve({
+ shader_layout.handle,
+ }),
+ shader.cache.retrieve(shader_descriptor),
+ local_work_group,
+ }));
+
+ return descriptor.pool.allocate(shader_layout);
+}
+
+void dispatch_epilogue(
+ Command::Buffer& command_buffer,
+ const Descriptor::Set& descriptor_set,
+ const Shader::WorkGroup& global_work_group) {
+ command_buffer.bind(descriptor_set);
+ command_buffer.dispatch(global_work_group);
+}
+
} // namespace api
} // namespace vulkan
} // namespace native
diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h
index c272312..2be8557 100644
--- a/aten/src/ATen/native/vulkan/api/Context.h
+++ b/aten/src/ATen/native/vulkan/api/Context.h
@@ -29,7 +29,7 @@
Context(Context&&) = default;
Context& operator=(const Context&) = delete;
Context& operator=(Context&&) = default;
- ~Context() = default;
+ ~Context();
GPU gpu();
Command& command();
@@ -38,20 +38,31 @@
Descriptor& descriptor();
Resource& resource();
+ // GPU RPC
+
+ template<typename... Arguments>
+ void dispatch(
+ Command::Buffer& command_buffer,
+ const Shader::Layout::Signature& shader_layout_signature,
+ const Shader::Descriptor& shader_descriptor,
+ const Shader::WorkGroup& local_work_group,
+ const Shader::WorkGroup& global_work_group,
+ Arguments&&... arguments);
+
+ // This function is expensive and its use is consequential for performance. Only
+ // use this function for debugging or as a short-term hack on the way to a more
+ // performant solution.
+
+ void flush();
+
private:
VkDevice device();
VkQueue queue();
private:
- class Deleter final {
- public:
- void operator()(VkDevice device) const;
- };
-
- private:
// Construction and destruction order matters. Do not move members around.
Adapter adapter_;
- Handle<VkDevice, Deleter> device_;
+ Handle<VkDevice, decltype(&VK_DELETER(Device))> device_;
VkQueue queue_;
Command command_;
Shader shader_;
@@ -60,6 +71,7 @@
Resource resource_;
};
+bool available();
Context* context();
//
@@ -105,6 +117,62 @@
return queue_;
}
+namespace detail {
+
+template<
+ size_t...Indices,
+ typename ...Arguments>
+inline void bind(
+ Descriptor::Set& descriptor_set,
+ const std::index_sequence<Indices...>,
+ Arguments&&...arguments) {
+ C10_UNUSED const int _[]{
+ (descriptor_set.bind(Indices, arguments), 0)...,
+ };
+}
+
+} // namespace detail
+
+template<typename... Arguments>
+inline void Context::dispatch(
+ Command::Buffer& command_buffer,
+ const Shader::Layout::Signature& shader_layout_signature,
+ const Shader::Descriptor& shader_descriptor,
+ const Shader::WorkGroup& local_work_group,
+ const Shader::WorkGroup& global_work_group,
+ Arguments&&... arguments) {
+ // Forward declaration
+ Descriptor::Set dispatch_prologue(
+ Command::Buffer&,
+ const Shader::Layout::Signature&,
+ const Shader::Descriptor&,
+ const Shader::WorkGroup&);
+
+ // Factor out template-parameter-independent code to minimize code bloat.
+ Descriptor::Set descriptor_set = dispatch_prologue(
+ command_buffer,
+ shader_layout_signature,
+ shader_descriptor,
+ local_work_group);
+
+ detail::bind(
+ descriptor_set,
+ std::index_sequence_for<Arguments...>{},
+ std::forward<Arguments>(arguments)...);
+
+ // Forward declaration
+ void dispatch_epilogue(
+ Command::Buffer&,
+ const Descriptor::Set&,
+ const Shader::WorkGroup&);
+
+ // Factor out template-parameter-independent code to minimize code bloat.
+ dispatch_epilogue(
+ command_buffer,
+ descriptor_set,
+ global_work_group);
+}
+
} // namespace api
} // namespace vulkan
} // namespace native
diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp
index ab136c0..037f793 100644
--- a/aten/src/ATen/native/vulkan/api/Descriptor.cpp
+++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp
@@ -119,13 +119,15 @@
Descriptor::Set::Set(
const VkDevice device,
const VkDescriptorPool descriptor_pool,
- const VkDescriptorSetLayout descriptor_set_layout)
+ const Shader::Layout::Object& shader_layout)
: device_(device),
descriptor_set_(
allocate_descriptor_set(
device_,
descriptor_pool,
- descriptor_set_layout)) {
+ shader_layout.handle)),
+ shader_layout_signature_(shader_layout.signature),
+ bindings_{} {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
descriptor_set_,
"Invalid Vulkan descriptor set!");
@@ -156,7 +158,6 @@
Descriptor::Set& Descriptor::Set::bind(
const uint32_t binding,
- const VkDescriptorType type,
const Resource::Buffer::Object& buffer) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
device_,
@@ -165,7 +166,7 @@
update({
binding,
- type,
+ shader_layout_signature_[binding],
{
.buffer = {
buffer.handle,
@@ -180,7 +181,6 @@
Descriptor::Set& Descriptor::Set::bind(
const uint32_t binding,
- const VkDescriptorType type,
const Resource::Image::Object& image) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
device_,
@@ -189,7 +189,7 @@
update({
binding,
- type,
+ shader_layout_signature_[binding],
{
.image = {
image.sampler,
@@ -309,7 +309,7 @@
}
Descriptor::Set Descriptor::Pool::allocate(
- const VkDescriptorSetLayout descriptor_set_layout)
+ const Shader::Layout::Object& shader_layout)
{
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
device_ && descriptor_pool_,
@@ -317,13 +317,13 @@
"Potential reason: This descriptor pool is moved from.");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
- descriptor_set_layout,
- "Invalid Vulkan descriptor set layout!");
+ shader_layout,
+ "Invalid Vulkan shader layout!");
return Set(
device_,
descriptor_pool_.get(),
- descriptor_set_layout);
+ shader_layout);
}
void Descriptor::Pool::purge() {
diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h
index 72550c5..c005e1f 100644
--- a/aten/src/ATen/native/vulkan/api/Descriptor.h
+++ b/aten/src/ATen/native/vulkan/api/Descriptor.h
@@ -2,6 +2,7 @@
#include <ATen/native/vulkan/api/Common.h>
#include <ATen/native/vulkan/api/Resource.h>
+#include <ATen/native/vulkan/api/Shader.h>
namespace at {
namespace native {
@@ -58,7 +59,7 @@
Set(
VkDevice device,
VkDescriptorPool descriptor_pool,
- VkDescriptorSetLayout descriptor_set_layout);
+ const Shader::Layout::Object& shader_layout);
Set(const Set&) = delete;
Set& operator=(const Set&) = delete;
Set(Set&&);
@@ -67,12 +68,10 @@
Set& bind(
uint32_t binding,
- VkDescriptorType type,
const Resource::Buffer::Object& buffer);
Set& bind(
uint32_t binding,
- VkDescriptorType type,
const Resource::Image::Object& image);
VkDescriptorSet handle() const;
@@ -92,6 +91,7 @@
private:
VkDevice device_;
VkDescriptorSet descriptor_set_;
+ Shader::Layout::Signature shader_layout_signature_;
struct {
c10::SmallVector<Item, 8u> items;
@@ -112,7 +112,7 @@
Pool& operator=(Pool&&);
~Pool() = default;
- Set allocate(VkDescriptorSetLayout descriptor_set_layout);
+ Set allocate(const Shader::Layout::Object& shader_layout);
void purge();
private:
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
index d7067c2..93c9543 100644
--- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp
@@ -132,7 +132,7 @@
3u,
specialization_map_entires,
sizeof(Shader::WorkGroup),
- &descriptor.work_group,
+ &descriptor.local_work_group,
};
const VkComputePipelineCreateInfo compute_pipeline_create_info{
diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h
index b168b91..5e3e419 100644
--- a/aten/src/ATen/native/vulkan/api/Pipeline.h
+++ b/aten/src/ATen/native/vulkan/api/Pipeline.h
@@ -99,7 +99,7 @@
struct Descriptor final {
VkPipelineLayout pipeline_layout;
VkShaderModule shader_module;
- Shader::WorkGroup work_group;
+ Shader::WorkGroup local_work_group;
};
/*
@@ -132,6 +132,7 @@
struct Object final {
VkPipeline handle;
VkPipelineLayout layout;
+ Shader::WorkGroup local_work_group;
operator bool() const;
};
@@ -182,7 +183,7 @@
const Pipeline::Descriptor& _2) {
return (_1.pipeline_layout == _2.pipeline_layout) &&
(_1.shader_module == _2.shader_module) &&
- (_1.work_group == _2.work_group);
+ (_1.local_work_group == _2.local_work_group);
}
inline size_t Pipeline::Factory::Hasher::operator()(
@@ -190,9 +191,14 @@
return c10::get_hash(
descriptor.pipeline_layout,
descriptor.shader_module,
- descriptor.work_group.x,
- descriptor.work_group.y,
- descriptor.work_group.z);
+ descriptor.local_work_group.x,
+ descriptor.local_work_group.y,
+ descriptor.local_work_group.z);
+}
+
+inline Pipeline::Object::operator bool() const {
+ return (VK_NULL_HANDLE != handle) &&
+ (VK_NULL_HANDLE != layout);
}
inline Pipeline::Object Pipeline::Cache::retrieve(
@@ -200,6 +206,7 @@
return {
cache_.retrieve(descriptor),
descriptor.pipeline_layout,
+ descriptor.local_work_group,
};
}
@@ -207,11 +214,6 @@
cache_.purge();
}
-inline Pipeline::Object::operator bool() const {
- return (VK_NULL_HANDLE != handle) &&
- (VK_NULL_HANDLE != layout);
-}
-
} // namespace api
} // namespace vulkan
} // namespace native
diff --git a/aten/src/ATen/native/vulkan/api/Runtime.cpp b/aten/src/ATen/native/vulkan/api/Runtime.cpp
index ce6e3b4..ce5fdf0 100644
--- a/aten/src/ATen/native/vulkan/api/Runtime.cpp
+++ b/aten/src/ATen/native/vulkan/api/Runtime.cpp
@@ -323,10 +323,6 @@
return runtime.get();
}
-bool available() {
- return initialize();
-}
-
Runtime* runtime() {
Runtime* const runtime = initialize();
TORCH_CHECK(
diff --git a/aten/src/ATen/native/vulkan/api/Runtime.h b/aten/src/ATen/native/vulkan/api/Runtime.h
index ffc031f..4d06da8 100644
--- a/aten/src/ATen/native/vulkan/api/Runtime.h
+++ b/aten/src/ATen/native/vulkan/api/Runtime.h
@@ -55,7 +55,6 @@
Handle<VkDebugReportCallbackEXT, Debug> debug_report_callback_;
};
-bool available();
Runtime* runtime();
} // namespace api
diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp
index e3c336d..2c090d0 100644
--- a/aten/src/ATen/native/vulkan/api/Shader.cpp
+++ b/aten/src/ATen/native/vulkan/api/Shader.cpp
@@ -9,7 +9,6 @@
namespace vulkan {
namespace api {
-
Shader::Layout::Factory::Factory(const GPU& gpu)
: device_(gpu.device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
@@ -19,12 +18,25 @@
Shader::Layout::Factory::Handle Shader::Layout::Factory::operator()(
const Descriptor& descriptor) const {
+ c10::SmallVector<VkDescriptorSetLayoutBinding, 8u> bindings;
+
+ uint32_t binding = 0u;
+ for (const VkDescriptorType type : descriptor.signature) {
+ bindings.push_back({
+ binding++,
+ type,
+ 1u,
+ VK_SHADER_STAGE_COMPUTE_BIT,
+ nullptr,
+ });
+ }
+
const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{
VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
nullptr,
0u,
- static_cast<uint32_t>(descriptor.bindings.size()),
- descriptor.bindings.data(),
+ static_cast<uint32_t>(bindings.size()),
+ bindings.data(),
};
VkDescriptorSetLayout descriptor_set_layout{};
@@ -44,6 +56,10 @@
};
}
+Shader::Layout::Cache::Cache(Factory factory)
+ : cache_(std::move(factory)) {
+}
+
#ifdef USE_VULKAN_SHADERC_RUNTIME
struct Shader::Factory::Compiler final {
diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h
index 599ed79..8de369c 100644
--- a/aten/src/ATen/native/vulkan/api/Shader.h
+++ b/aten/src/ATen/native/vulkan/api/Shader.h
@@ -39,11 +39,17 @@
struct Layout final {
/*
+ Signature
+ */
+
+ typedef c10::SmallVector<VkDescriptorType, 8u> Signature;
+
+ /*
Descriptor
*/
struct Descriptor final {
- c10::SmallVector<VkDescriptorSetLayoutBinding, 16u> bindings;
+ Signature signature;
};
/*
@@ -68,12 +74,32 @@
VkDevice device_;
};
+ struct Object final {
+ VkDescriptorSetLayout handle;
+ Signature signature;
+
+ operator bool() const;
+ };
+
/*
Cache
*/
- typedef api::Cache<Factory> Cache;
- Cache cache;
+ class Cache final {
+ public:
+ explicit Cache(Factory factory);
+ Cache(const Cache&) = delete;
+ Cache& operator=(const Cache&) = delete;
+ Cache(Cache&&) = default;
+ Cache& operator=(Cache&&) = default;
+ ~Cache() = default;
+
+ Object retrieve(const Descriptor& descriptor);
+ void purge();
+
+ private:
+ api::Cache<Factory> cache_;
+ } cache;
explicit Layout(const GPU& gpu)
: cache(Factory(gpu)) {
@@ -165,27 +191,38 @@
inline bool operator==(
const Shader::Layout::Descriptor& _1,
const Shader::Layout::Descriptor& _2) {
- return _1.bindings == _2.bindings;
+ return _1.signature == _2.signature;
}
inline size_t Shader::Layout::Factory::Hasher::operator()(
const Descriptor& descriptor) const {
size_t hash = 0u;
- for (const VkDescriptorSetLayoutBinding& binding : descriptor.bindings) {
+ for (const VkDescriptorType type : descriptor.signature) {
hash = c10::hash_combine(
hash,
- c10::get_hash(
- binding.binding,
- binding.descriptorType,
- binding.descriptorCount,
- binding.stageFlags,
- binding.pImmutableSamplers));
+ c10::get_hash(type));
}
return hash;
}
+inline Shader::Layout::Object::operator bool() const {
+ return VK_NULL_HANDLE != handle;
+}
+
+inline Shader::Layout::Object Shader::Layout::Cache::retrieve(
+ const Descriptor& descriptor) {
+ return {
+ cache_.retrieve(descriptor),
+ descriptor.signature,
+ };
+}
+
+inline void Shader::Layout::Cache::purge() {
+ cache_.purge();
+}
+
inline bool operator==(
const Shader::WorkGroup& _1,
const Shader::WorkGroup& _2) {