blob: dece9ddb50da0ff406626b66e9131f2deb99e456 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
namespace vkcompute {
class ComputeGraph;

/*
 * A bundle of shader arguments (image and/or buffer values) that all share one
 * memory access permission.
 */
struct ArgGroup {
  // Builds a group that holds exactly one value.
  ArgGroup(const ValueRef ref, const vkapi::MemoryAccessType access)
      : ArgGroup(std::vector<ValueRef>{ref}, access) {}

  // Builds a group that holds every value of `refs`, in the given order.
  ArgGroup(
      const std::vector<ValueRef>& refs,
      const vkapi::MemoryAccessType access)
      : refs(refs), access(access) {}

  // Values belonging to this group.
  const std::vector<ValueRef> refs;
  // Access permission shared by every value in `refs`.
  const vkapi::MemoryAccessType access;
};
/*
 * Represents a single execution op in an ML model. An op is described by the
 * shader to dispatch, its global and local workgroup sizes, its argument
 * groups, its parameter bindings and its specialization constants; encode()
 * implements encoding of the shader corresponding to the op into the command
 * buffer of a ComputeGraph.
 *
 * The class is final, so ops are not implemented as derived classes: every op
 * is an ExecuteNode instance configured through its constructor arguments. A
 * node can also be built without a shader (see the resize-only constructor)
 * so that it only takes part in resizing.
 */
class ExecuteNode final {
  friend class ComputeGraph;

 public:
  /*
   * Callback run by trigger_resize(). It receives the graph passed to
   * trigger_resize(), this node's argument groups (args_) and its extra
   * resize arguments (resize_args_).
   */
  using ResizeFunction = const std::function<void(
      ComputeGraph*,
      const std::vector<ArgGroup>&,
      const std::vector<ValueRef>&)>;

  /*
   * Builds a node that dispatches a compute shader.
   *
   * graph: the ComputeGraph the node is created for (how it is used is defined
   *   in the implementation file).
   * shader: the shader to dispatch.
   * global_workgroup_size, local_workgroup_size: dispatch dimensions.
   * args: shader arguments, grouped by memory access permission.
   * params: parameter bindings for the shader.
   * spec_vars: specialization constants; empty by default.
   * resize_fn: optional callback run by trigger_resize(); none by default.
   * resize_args: extra values forwarded to resize_fn; empty by default.
   */
  explicit ExecuteNode(
      ComputeGraph& graph,
      const vkapi::ShaderInfo& shader,
      const utils::uvec3& global_workgroup_size,
      const utils::uvec3& local_workgroup_size,
      const std::vector<ArgGroup>& args,
      const vkapi::ParamsBindList& params,
      const vkapi::SpecVarList& spec_vars = {},
      const ResizeFunction& resize_fn = nullptr,
      const std::vector<ValueRef>& resize_args = {});

  /*
   * This overload of the ExecuteNode constructor is used to register ops which
   * update a tensor view. No shader is dispatched, but the node still needs to
   * update the view's sizes and strides after a resize.
   */
  explicit ExecuteNode(
      const ResizeFunction& resize_fn = nullptr,
      const std::vector<ValueRef>& resize_args = {});

  ~ExecuteNode() = default;

  // Encodes this node into the command buffer of `graph`.
  void encode(ComputeGraph* graph);

  // Invokes resize_fn_ with (graph, args_, resize_args_); does nothing when
  // the node was built without a resize function.
  void trigger_resize(ComputeGraph* graph) {
    if (resize_fn_ != nullptr) {
      resize_fn_(graph, args_, resize_args_);
    }
  }

  // Stores this node's identifier (presumably assigned by ComputeGraph when
  // the node is added — confirm at the call site).
  void set_node_id(uint32_t node_id) {
    node_id_ = node_id;
  }

  // True when the node holds a shader, i.e. when shader_ converts to true.
  // Resize-only nodes presumably carry an empty shader and yield false —
  // confirm in the implementation file.
  // NOTE(review): this conversion is implicit; consider making it `explicit`
  // once all callers are known to use it only in boolean contexts.
  operator bool() const {
    return shader_;
  }

 private:
  // The class is final, so `protected` granted nothing beyond `private`; the
  // friend ComputeGraph still has full access.

  // Default-initialized so the id is never indeterminate before set_node_id()
  // runs; neither constructor is required to set it.
  uint32_t node_id_{0u};
  const vkapi::ShaderInfo shader_;
  const utils::uvec3 global_workgroup_size_;
  const utils::uvec3 local_workgroup_size_;
  const std::vector<ArgGroup> args_;
  const vkapi::ParamsBindList params_;
  const vkapi::SpecVarList spec_vars_;
  const ResizeFunction resize_fn_;
  const std::vector<ValueRef> resize_args_;
};
} // namespace vkcompute