| /* |
| * Copyright (c) Meta Platforms, Inc. and affiliates. |
| * All rights reserved. |
| * |
| * This source code is licensed under the BSD-style license found in the |
| * LICENSE file in the root directory of this source tree. |
| */ |
| |
| #include <executorch/backends/vulkan/runtime/vk_api/Shader.h> |
| |
| #include <utility> |
| |
| namespace vkcompute { |
| namespace vkapi { |
| |
| // |
| // ShaderInfo |
| // |
| |
| ShaderInfo::ShaderInfo() |
| : src_code{ |
| nullptr, |
| 0u, |
| } {} |
| |
| ShaderInfo::ShaderInfo( |
| std::string name, |
| const uint32_t* const spirv_bin, |
| const uint32_t size, |
| std::vector<VkDescriptorType> layout, |
| const utils::uvec3 tile_size) |
| : src_code{ |
| spirv_bin, |
| size, |
| }, |
| kernel_name{std::move(name)}, |
| kernel_layout{std::move(layout)}, |
| out_tile_size(tile_size) { |
| } |
| |
| bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { |
| return ( |
| _1.src_code.bin == _2.src_code.bin && |
| _1.src_code.size == _2.src_code.size); |
| } |
| |
| // |
| // ShaderLayout |
| // |
| |
| ShaderLayout::ShaderLayout( |
| VkDevice device, |
| const ShaderLayout::Signature& signature) |
| : device_(device), handle_{VK_NULL_HANDLE} { |
| std::vector<VkDescriptorSetLayoutBinding> bindings; |
| |
| uint32_t binding_num = 0u; |
| for (const VkDescriptorType type : signature) { |
| bindings.push_back({ |
| binding_num++, // binding |
| type, // descriptorType |
| 1u, // descriptorCount |
| VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags |
| nullptr, // pImmutableSamplers |
| }); |
| } |
| |
| const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{ |
| VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType |
| nullptr, // pNext |
| 0u, // flags |
| static_cast<uint32_t>(bindings.size()), // bindingCount |
| bindings.data(), // pBindings |
| }; |
| |
| VK_CHECK(vkCreateDescriptorSetLayout( |
| device_, &descriptor_set_layout_create_info, nullptr, &handle_)); |
| } |
| |
| ShaderLayout::ShaderLayout(ShaderLayout&& other) noexcept |
| : device_(other.device_), handle_(other.handle_) { |
| other.handle_ = VK_NULL_HANDLE; |
| } |
| |
| ShaderLayout::~ShaderLayout() { |
| if (VK_NULL_HANDLE == handle_) { |
| return; |
| } |
| vkDestroyDescriptorSetLayout(device_, handle_, nullptr); |
| handle_ = VK_NULL_HANDLE; |
| } |
| |
| void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept { |
| VkDevice tmp_device = lhs.device_; |
| VkDescriptorSetLayout tmp_handle = lhs.handle_; |
| |
| lhs.device_ = rhs.device_; |
| lhs.handle_ = rhs.handle_; |
| |
| rhs.device_ = tmp_device; |
| rhs.handle_ = tmp_handle; |
| } |
| |
| // |
| // ShaderModule |
| // |
| |
| ShaderModule::ShaderModule(VkDevice device, const ShaderInfo& source) |
| : device_(device), handle_{VK_NULL_HANDLE} { |
| const uint32_t* code = source.src_code.bin; |
| uint32_t size = source.src_code.size; |
| |
| const VkShaderModuleCreateInfo shader_module_create_info{ |
| VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType |
| nullptr, // pNext |
| 0u, // flags |
| size, // codeSize |
| code, // pCode |
| }; |
| |
| VK_CHECK(vkCreateShaderModule( |
| device_, &shader_module_create_info, nullptr, &handle_)); |
| } |
| |
| ShaderModule::ShaderModule(ShaderModule&& other) noexcept |
| : device_(other.device_), handle_(other.handle_) { |
| other.handle_ = VK_NULL_HANDLE; |
| } |
| |
| ShaderModule::~ShaderModule() { |
| if (VK_NULL_HANDLE == handle_) { |
| return; |
| } |
| vkDestroyShaderModule(device_, handle_, nullptr); |
| handle_ = VK_NULL_HANDLE; |
| } |
| |
| void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept { |
| VkDevice tmp_device = lhs.device_; |
| VkShaderModule tmp_handle = lhs.handle_; |
| |
| lhs.device_ = rhs.device_; |
| lhs.handle_ = rhs.handle_; |
| |
| rhs.device_ = tmp_device; |
| rhs.handle_ = tmp_handle; |
| } |
| |
| // |
| // ShaderLayoutCache |
| // |
| |
| ShaderLayoutCache::ShaderLayoutCache(VkDevice device) |
| : cache_mutex_{}, device_(device), cache_{} {} |
| |
| ShaderLayoutCache::ShaderLayoutCache(ShaderLayoutCache&& other) noexcept |
| : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) { |
| std::lock_guard<std::mutex> lock(other.cache_mutex_); |
| } |
| |
| ShaderLayoutCache::~ShaderLayoutCache() { |
| purge(); |
| } |
| |
| VkDescriptorSetLayout ShaderLayoutCache::retrieve( |
| const ShaderLayoutCache::Key& key) { |
| std::lock_guard<std::mutex> lock(cache_mutex_); |
| |
| auto it = cache_.find(key); |
| if (cache_.cend() == it) { |
| it = cache_.insert({key, ShaderLayoutCache::Value(device_, key)}).first; |
| } |
| |
| return it->second.handle(); |
| } |
| |
| void ShaderLayoutCache::purge() { |
| std::lock_guard<std::mutex> lock(cache_mutex_); |
| cache_.clear(); |
| } |
| |
| // |
| // ShaderCache |
| // |
| |
| ShaderCache::ShaderCache(VkDevice device) |
| : cache_mutex_{}, device_(device), cache_{} {} |
| |
| ShaderCache::ShaderCache(ShaderCache&& other) noexcept |
| : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) { |
| std::lock_guard<std::mutex> lock(other.cache_mutex_); |
| } |
| |
| ShaderCache::~ShaderCache() { |
| purge(); |
| } |
| |
| VkShaderModule ShaderCache::retrieve(const ShaderCache::Key& key) { |
| std::lock_guard<std::mutex> lock(cache_mutex_); |
| |
| auto it = cache_.find(key); |
| if (cache_.cend() == it) { |
| it = cache_.insert({key, ShaderCache::Value(device_, key)}).first; |
| } |
| |
| return it->second.handle(); |
| } |
| |
| void ShaderCache::purge() { |
| cache_.clear(); |
| } |
| |
| } // namespace vkapi |
| } // namespace vkcompute |