blob: 713fd9917e792d92206353d59600263251b202e2 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>
#include <executorch/backends/vulkan/runtime/vk_api/Command.h>
#include <mutex>
namespace vkcompute {
namespace vkapi {
//
// CommandBuffer
//
CommandBuffer::CommandBuffer(
VkCommandBuffer handle,
const VkCommandBufferUsageFlags flags)
: handle_(handle),
flags_(flags),
state_(CommandBuffer::State::NEW),
bound_{} {}
// Move constructor: takes over the source's handle, usage flags, lifecycle
// state and bound-object cache, leaving the source in the INVALID state.
//
// NOTE(review): this previously initialized state_ to INVALID instead of
// adopting other.state_, and left the moved-from buffer's state untouched —
// inconsistent with the move-assignment operator below, and it rendered any
// move-constructed buffer unusable (every VK_CHECK_COND on state_ would
// fail). Transfer the state and invalidate the source, matching operator=.
CommandBuffer::CommandBuffer(CommandBuffer&& other) noexcept
    : handle_(other.handle_),
      flags_(other.flags_),
      state_(other.state_),
      bound_(other.bound_) {
  other.handle_ = VK_NULL_HANDLE;
  other.bound_.reset();
  other.state_ = CommandBuffer::State::INVALID;
}
// Move assignment: adopts the source's handle, flags, state and bound-object
// cache, then resets the source to a safely inert (INVALID) state.
//
// The handle is not owned by CommandBuffer (its CommandPool owns it), so no
// release of the previously held handle is required before overwriting it.
CommandBuffer& CommandBuffer::operator=(CommandBuffer&& other) noexcept {
  // Guard against self-move-assignment, which would otherwise null out
  // handle_ and invalidate state_ after copying them onto themselves.
  if (this != &other) {
    handle_ = other.handle_;
    flags_ = other.flags_;
    state_ = other.state_;
    bound_ = other.bound_;
    other.handle_ = VK_NULL_HANDLE;
    other.bound_.reset();
    other.state_ = CommandBuffer::State::INVALID;
  }
  return *this;
}
void CommandBuffer::begin() {
VK_CHECK_COND(
state_ == CommandBuffer::State::NEW,
"Vulkan CommandBuffer: called begin() on a command buffer whose state "
"is not NEW.");
const VkCommandBufferBeginInfo begin_info{
VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
nullptr,
flags_,
nullptr,
};
VK_CHECK(vkBeginCommandBuffer(handle_, &begin_info));
state_ = CommandBuffer::State::RECORDING;
}
// Closes recording (if active) and marks the buffer READY for submission.
void CommandBuffer::end() {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::RECORDING ||
          state_ == CommandBuffer::State::SUBMITTED,
      "Vulkan CommandBuffer: called end() on a command buffer whose state "
      "is not RECORDING or SUBMITTED.");
  // Only an actively-recording buffer needs vkEndCommandBuffer; a SUBMITTED
  // (reusable) buffer was already ended before its previous submission.
  const bool is_recording = (state_ == CommandBuffer::State::RECORDING);
  if (is_recording) {
    VK_CHECK(vkEndCommandBuffer(handle_));
  }
  state_ = CommandBuffer::State::READY;
}
// Binds a compute pipeline and caches the layout / local workgroup size
// needed later by bind_descriptors() and dispatch().
void CommandBuffer::bind_pipeline(
    VkPipeline pipeline,
    VkPipelineLayout pipeline_layout,
    const utils::uvec3 local_workgroup_size) {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::RECORDING,
      "Vulkan CommandBuffer: called bind_pipeline() on a command buffer whose state "
      "is not RECORDING.");
  // Skip the redundant driver call when this pipeline is already bound.
  const bool already_bound = (pipeline == bound_.pipeline);
  if (!already_bound) {
    vkCmdBindPipeline(handle_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
    bound_.pipeline = pipeline;
  }
  // Layout and workgroup size are cached unconditionally, even when the
  // pipeline itself did not change.
  bound_.pipeline_layout = pipeline_layout;
  bound_.local_workgroup_size = local_workgroup_size;
  state_ = CommandBuffer::State::PIPELINE_BOUND;
}
// Binds a descriptor set (set index 0) against the cached pipeline layout,
// skipping the driver call if the same set is already bound.
void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::PIPELINE_BOUND,
      "Vulkan CommandBuffer: called bind_descriptors() on a command buffer whose state "
      "is not PIPELINE_BOUND.");
  const bool needs_bind = (descriptors != bound_.descriptors);
  if (needs_bind) {
    vkCmdBindDescriptorSets(
        handle_, // commandBuffer
        VK_PIPELINE_BIND_POINT_COMPUTE, // pipelineBindPoint
        bound_.pipeline_layout, // layout
        0u, // firstSet
        1u, // descriptorSetCount
        &descriptors, // pDescriptorSets
        0u, // dynamicOffsetCount
        nullptr); // pDynamicOffsets
  }
  // Record the bound set either way so future calls can short-circuit.
  bound_.descriptors = descriptors;
  state_ = CommandBuffer::State::DESCRIPTORS_BOUND;
}
// Records a vkCmdPipelineBarrier for the buffer/image memory barriers
// described by pipeline_barrier. The barrier's flat handle arrays are
// rebuilt from its buffers/images lists on every call. No-op (aside from
// the state transition) when pipeline_barrier is empty/falsey.
void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::DESCRIPTORS_BOUND ||
          state_ == CommandBuffer::State::RECORDING,
      "Vulkan CommandBuffer: called insert_barrier() on a command buffer whose state "
      "is not DESCRIPTORS_BOUND or RECORDING.");
  if (pipeline_barrier) {
    // clear() is already a no-op on an empty vector, so no emptiness
    // pre-check is needed; reserve up front to avoid repeated growth.
    pipeline_barrier.buffer_barrier_handles.clear();
    pipeline_barrier.buffer_barrier_handles.reserve(
        pipeline_barrier.buffers.size());
    for (const BufferMemoryBarrier& memory_barrier : pipeline_barrier.buffers) {
      pipeline_barrier.buffer_barrier_handles.push_back(memory_barrier.handle);
    }
    pipeline_barrier.image_barrier_handles.clear();
    pipeline_barrier.image_barrier_handles.reserve(
        pipeline_barrier.images.size());
    for (const ImageMemoryBarrier& memory_barrier : pipeline_barrier.images) {
      pipeline_barrier.image_barrier_handles.push_back(memory_barrier.handle);
    }
    vkCmdPipelineBarrier(
        handle_, // commandBuffer
        pipeline_barrier.stage.src, // srcStageMask
        pipeline_barrier.stage.dst, // dstStageMask
        0u, // dependencyFlags
        0u, // memoryBarrierCount
        nullptr, // pMemoryBarriers
        // Explicit casts: the Vulkan API takes uint32_t counts, while
        // vector::size() returns size_t (implicit narrowing otherwise).
        static_cast<uint32_t>(
            pipeline_barrier.buffers.size()), // bufferMemoryBarrierCount
        !pipeline_barrier.buffers.empty()
            ? pipeline_barrier.buffer_barrier_handles.data()
            : nullptr, // pBufferMemoryBarriers
        static_cast<uint32_t>(
            pipeline_barrier.images.size()), // imageMemoryBarrierCount
        !pipeline_barrier.images.empty()
            ? pipeline_barrier.image_barrier_handles.data()
            : nullptr); // pImageMemoryBarriers
  }
  state_ = CommandBuffer::State::BARRIERS_INSERTED;
}
// Records a compute dispatch covering global_workgroup_size invocations,
// using the local workgroup size cached by the last bind_pipeline() call.
void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::BARRIERS_INSERTED,
      "Vulkan CommandBuffer: called dispatch() on a command buffer whose state "
      "is not BARRIERS_INSERTED.");
  // Per-axis workgroup counts, rounded up so that a partial trailing
  // workgroup is still launched.
  const auto groups_x = utils::div_up(
      global_workgroup_size[0u], bound_.local_workgroup_size[0u]);
  const auto groups_y = utils::div_up(
      global_workgroup_size[1u], bound_.local_workgroup_size[1u]);
  const auto groups_z = utils::div_up(
      global_workgroup_size[2u], bound_.local_workgroup_size[2u]);
  vkCmdDispatch(handle_, groups_x, groups_y, groups_z);
  // Drop back to RECORDING so another bind/dispatch sequence can follow.
  state_ = CommandBuffer::State::RECORDING;
}
// Records a timestamp write into slot idx of querypool, taken after all
// previously recorded work completes (bottom-of-pipe stage).
void CommandBuffer::write_timestamp(VkQueryPool querypool, const uint32_t idx)
    const {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::RECORDING,
      "Vulkan CommandBuffer: called write_timestamp() on a command buffer whose state "
      "is not RECORDING.");
  vkCmdWriteTimestamp(
      handle_, // commandBuffer
      VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, // pipelineStage
      querypool, // queryPool
      idx); // query
}
// Records a reset of `count` queries starting at first_idx in querypool,
// which must happen before those slots can receive new timestamps.
void CommandBuffer::reset_querypool(
    VkQueryPool querypool,
    const uint32_t first_idx,
    const uint32_t count) const {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::RECORDING,
      "Vulkan CommandBuffer: called reset_querypool() on a command buffer whose state "
      "is not RECORDING.");
  vkCmdResetQueryPool(
      handle_, // commandBuffer
      querypool, // queryPool
      first_idx, // firstQuery
      count); // queryCount
}
// Returns the raw handle for queue submission and transitions READY ->
// SUBMITTED. A one-time-submit buffer, or a reusable buffer on its final
// use, is invalidated first so it cannot be submitted again.
//
// Fixed: the failure message previously said "called begin()" — a
// copy/paste error from begin(); it now names get_submit_handle().
VkCommandBuffer CommandBuffer::get_submit_handle(const bool final_use) {
  VK_CHECK_COND(
      state_ == CommandBuffer::State::READY,
      "Vulkan CommandBuffer: called get_submit_handle() on a command buffer "
      "whose state is not READY.");
  // Stash the handle before invalidate() may null out handle_.
  VkCommandBuffer handle = handle_;
  if (!is_reusable() || final_use) {
    invalidate();
  }
  state_ = CommandBuffer::State::SUBMITTED;
  return handle;
}
//
// CommandPool
//
// Creates a transient command pool on the given queue family and
// pre-allocates an initial batch of command buffers per the config.
CommandPool::CommandPool(
    VkDevice device,
    const uint32_t queue_family_idx,
    const CommandPoolConfig& config)
    : device_(device),
      queue_family_idx_(queue_family_idx),
      pool_(VK_NULL_HANDLE),
      config_(config),
      mutex_{},
      buffers_{},
      in_use_(0u) {
  // TRANSIENT hints to the driver that buffers from this pool are
  // short-lived and frequently reset.
  const VkCommandPoolCreateInfo create_info{
      VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO, // sType
      nullptr, // pNext
      VK_COMMAND_POOL_CREATE_TRANSIENT_BIT, // flags
      queue_family_idx_, // queueFamilyIndex
  };
  VK_CHECK(vkCreateCommandPool(device_, &create_info, nullptr, &pool_));
  // Warm the pool so the first get_new_cmd() calls do not allocate.
  allocate_new_batch(config_.cmd_pool_initial_size);
}
// Destroys the pool (which frees every command buffer allocated from it).
// A null pool_ — e.g. after a failed construction — is skipped.
CommandPool::~CommandPool() {
  if (pool_ != VK_NULL_HANDLE) {
    vkDestroyCommandPool(device_, pool_, nullptr);
  }
}
// Hands out the next unused command buffer from the pool, growing the pool
// by a batch if none remain. Thread-safe via mutex_.
CommandBuffer CommandPool::get_new_cmd(bool reusable) {
  std::lock_guard<std::mutex> lock(mutex_);
  // Ensure at least one unused buffer is available (no-op otherwise).
  allocate_new_batch(config_.cmd_pool_batch_size);
  const VkCommandBuffer handle = buffers_[in_use_++];
  // Default to one-time-submit; omit the flag only when the caller
  // intends to replay the recorded buffer across multiple submissions.
  VkCommandBufferUsageFlags cmd_flags =
      reusable ? 0u : VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  return CommandBuffer(handle, cmd_flags);
}
// Resets every command buffer in the pool in one call and marks all of
// them as unused again. Thread-safe via mutex_.
void CommandPool::flush() {
  std::lock_guard<std::mutex> lock(mutex_);
  VK_CHECK(vkResetCommandPool(device_, pool_, 0u));
  in_use_ = 0u;
}
// Grows the pool by `count` primary command buffers, but only when every
// existing buffer has already been handed out. Not internally synchronized:
// callers (other than the constructor) must hold mutex_.
void CommandPool::allocate_new_batch(const uint32_t count) {
  // No-op if there are still command buffers available.
  if (in_use_ < buffers_.size()) {
    return;
  }
  // Guard count == 0: vkAllocateCommandBuffers requires
  // commandBufferCount > 0, and a zero-length resize would be pointless.
  if (count == 0u) {
    return;
  }
  const size_t first_new = buffers_.size();
  buffers_.resize(first_new + count);
  const VkCommandBufferAllocateInfo allocate_info{
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, // sType
      nullptr, // pNext
      pool_, // commandPool
      VK_COMMAND_BUFFER_LEVEL_PRIMARY, // level
      count, // commandBufferCount
  };
  // Write the new handles into the freshly appended slots. Using the
  // recorded old size (rather than in_use_) makes the write offset correct
  // by construction.
  VK_CHECK(vkAllocateCommandBuffers(
      device_, &allocate_info, buffers_.data() + first_new));
}
} // namespace vkapi
} // namespace vkcompute