blob: ba7613bd14d3567bd9feb3ac464fcca41c82fc53 [file]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
namespace vkcompute {
class ComputeGraph;
/*
* Represents a single shader execution op in an ML model.
*/
class DispatchNode final : public ExecuteNode {
  // Friendship lets ComputeGraph reach the protected dispatch state declared
  // further down in this class.
  friend class ComputeGraph;

 public:
  /*
   * Declares a node from a shader, a global and a local workgroup size, the
   * argument groups, and the parameter bindings. The specialization variable
   * list, the resize function, and the resize arguments are optional and
   * default to empty / nullptr / empty respectively.
   *
   * The constructor is only declared here; its definition is not part of this
   * header.
   */
  explicit DispatchNode(
      ComputeGraph& graph,
      const vkapi::ShaderInfo& shader,
      const utils::uvec3& global_workgroup_size,
      const utils::uvec3& local_workgroup_size,
      const std::vector<ArgGroup>& args,
      const vkapi::ParamsBindList& params,
      const vkapi::SpecVarList& spec_vars = {},
      const ResizeFunction& resize_fn = nullptr,
      const std::vector<ValueRef>& resize_args = {});

  ~DispatchNode() override = default;

  /*
   * Override of the encode() hook inherited through ExecuteNode. Only the
   * declaration is visible here.
   */
  void encode(ComputeGraph* graph) override;

  /*
   * Converts the node to bool by converting the stored shader_ to bool.
   *
   * NOTE(review): this conversion is implicit (not marked `explicit`). It is
   * kept that way so existing callers that rely on implicit conversion keep
   * compiling; confirm against callers before tightening it.
   */
  operator bool() const {
    return shader_;
  }

 protected:
  // Immutable per-node dispatch state. Presumably filled in by the
  // constructor from the same-named arguments; verify against the
  // constructor definition.
  const vkapi::ShaderInfo shader_;
  const utils::uvec3 global_workgroup_size_;
  const utils::uvec3 local_workgroup_size_;
  const vkapi::ParamsBindList params_;
  const vkapi::SpecVarList spec_vars_;
};
} // namespace vkcompute