blob: 2cb00ba65afd15ef0995a75eada8dec66bc6bde7 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
namespace vkcompute {
// Constructs a node that will dispatch `shader` with the given global and
// local workgroup sizes, bindings (`args`, `params`), and specialization
// constants (`spec_vars`). `resize_fn` / `resize_args` are stored alongside.
// Every argument is copied into the node.
//
// The owning graph is informed of this shader's descriptor usage through
// ComputeGraph::update_descriptor_counts with execute == true (presumably so
// the graph can size its descriptor resources — confirm in ComputeGraph).
ExecuteNode::ExecuteNode(
    ComputeGraph& graph,
    const vkapi::ShaderInfo& shader,
    const utils::uvec3& global_workgroup_size,
    const utils::uvec3& local_workgroup_size,
    const std::vector<ArgGroup>& args,
    const vkapi::ParamsBindList& params,
    const vkapi::SpecVarList& spec_vars,
    const ResizeFunction& resize_fn,
    const std::vector<ValueRef>& resize_args)
    : shader_(shader),
      global_workgroup_size_(global_workgroup_size),
      local_workgroup_size_(local_workgroup_size),
      args_(args),
      params_(params),
      spec_vars_(spec_vars),
      resize_fn_(resize_fn),
      resize_args_(resize_args) {
  // Descriptor counts are being registered for an executing (dispatching) node.
  constexpr bool kCountsForExecuteNode = true;
  graph.update_descriptor_counts(shader, kCountsForExecuteNode);
}
ExecuteNode::ExecuteNode(
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
: shader_(),
global_workgroup_size_({0u, 0u, 0u}),
local_workgroup_size_({0u, 0u, 0u}),
args_(),
params_(),
spec_vars_(),
resize_fn_(resize_fn),
resize_args_(resize_args) {}
// Records this node's compute dispatch through the graph's api::Context.
//
// Nodes without a shader (see the resize-only constructor) record nothing.
// Otherwise, while holding the context's dispatch lock for the rest of the
// function, it:
//   1. reports the dispatch start (kernel name, workgroup sizes, node id),
//   2. obtains a descriptor set for the shader, local size and spec constants,
//   3. binds args_ starting at binding 0, then params_ from the index the
//      first step returns, collecting required barriers in a local barrier,
//   4. registers the dispatch with that descriptor set and barrier,
//   5. reports the dispatch end.
void ExecuteNode::encode(ComputeGraph* graph) {
  // Nothing to dispatch for a node that has no shader.
  if (!shader_) {
    return;
  }

  api::Context* const ctx = graph->context();
  vkapi::PipelineBarrier barrier{};

  // Released when this function returns; keeps the whole recording sequence
  // below under the context's dispatch lock.
  std::unique_lock<std::mutex> dispatch_guard = ctx->dispatch_lock();

  ctx->report_shader_dispatch_start(
      shader_.kernel_name,
      global_workgroup_size_,
      local_workgroup_size_,
      node_id_);

  vkapi::DescriptorSet desc_set =
      ctx->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);

  // Value arguments are bound from index 0. Parameter buffers start at the
  // index returned by that step (presumably the next free binding).
  const uint32_t first_param_binding =
      bind_values_to_descriptor_set(graph, args_, barrier, desc_set, 0u);
  bind_params_to_descriptor_set(params_, desc_set, first_param_binding);

  ctx->register_shader_dispatch(
      desc_set, barrier, shader_, global_workgroup_size_);
  ctx->report_shader_dispatch_end();
}
} // namespace vkcompute