/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
#ifdef USE_VULKAN_API
#include <ATen/native/vulkan/api/Context.h>
#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Types.h>

#include <executorch/backends/vulkan/runtime/graph/GraphConfig.h>
#include <executorch/backends/vulkan/runtime/graph/containers/SharedObject.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>

#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
namespace at {
namespace native {
namespace vulkan {
// Define valid scalar types that the Value class can accept
template <typename T>
struct is_valid_scalar_type : std::false_type {};
template <>
struct is_valid_scalar_type<int64_t> : std::true_type {};
template <>
struct is_valid_scalar_type<double> : std::true_type {};
template <>
struct is_valid_scalar_type<bool> : std::true_type {};
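// Illustrative note: per the specializations above, only int64_t, double and
// bool satisfy is_valid_scalar_type, so e.g. add_scalar<int64_t>(2) is
// well-formed while add_scalar<float>(2.f) is rejected at compile time by the
// enable_if constraints on add_scalar and add_scalar_list below.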
/*
 * This is the core data structure used to execute Vulkan models in graph mode.
 *
 * Unlike ATen/eager mode, where a new command buffer must be encoded for every
 * inference (since ops are dispatched one by one as the model runs), in graph
 * mode the ops that compose the model are parsed only once, at which point a
 * command buffer is encoded for them. Model inference then submits the cached
 * command buffer without having to encode a new one.
 *
 * An illustrative usage sketch is provided at the bottom of this file.
 */
class ComputeGraph final {
public:
explicit ComputeGraph(GraphConfig config);
ComputeGraph(ComputeGraph&&) = default;
ComputeGraph& operator=(ComputeGraph&&) = default;
~ComputeGraph();
private:
GraphConfig config_;
// Descriptor counts accumulated via update_descriptor_counts() while the
// graph is being built; used to size the descriptor pools for the prepack
// and execute command buffers.
api::DescriptorPoolConfig prepack_descriptor_counts_;
api::DescriptorPoolConfig execute_descriptor_counts_;
std::unique_ptr<api::Context> context_;
// Memory allocations that can be shared between multiple tensors.
std::vector<SharedObject> shared_objects_;
// All values (tensors, tensor refs, scalars, lists, strings, ...) owned by
// the graph, indexed by ValueRef.
std::vector<Value> values_;
// Operations encoded into the prepack and execute command buffers.
std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;
// References to the values that serve as model inputs and outputs.
std::vector<ValueRef> inputs_;
std::vector<ValueRef> outputs_;
public:
//
// Accessors
//
inline api::Context* context() {
return context_.get();
}
inline std::vector<ValueRef>& inputs() {
return inputs_;
}
inline std::vector<ValueRef>& outputs() {
return outputs_;
}
// Accumulates the descriptor counts required by a shader into the prepack or
// execute pool configuration, so that descriptor pools can be sized
// appropriately before command buffer encoding.
void update_descriptor_counts(
const api::ShaderInfo& shader_info,
bool execute);
/*
* Returns the value at a particular reference
*/
inline Value& get_val(ValueRef idx) {
return values_[idx];
}
inline const std::vector<int64_t>& get_val_sizes(ValueRef idx) {
Value& val = get_val(idx);
if (val.isTensor()) {
return val.toTensor().sizes();
} else if (val.isTensorRef()) {
return val.toTensorRef().sizes;
}
VK_THROW("Could not get sizes of value with type ", val.type());
}
inline api::ScalarType get_val_dtype(ValueRef idx) {
Value& val = get_val(idx);
if (val.isTensor()) {
return val.toTensor().dtype();
} else if (val.isTensorRef()) {
return val.toTensorRef().dtype;
}
VK_THROW("Could not get dtype of value with type ", val.type());
}
inline std::vector<std::unique_ptr<PrepackNode>>& prepack_nodes() {
return prepack_nodes_;
}
inline std::vector<std::unique_ptr<ExecuteNode>>& execute_nodes() {
return execute_nodes_;
}
//
// Graph Building
//
/*
 * Adds a vTensor with the given sizes and dtype to the graph. If
 * shared_object_idx is non-negative, the tensor's memory is backed by the
 * SharedObject at that index, allowing the allocation to be shared with
 * other tensors.
 */
ValueRef add_tensor(
const std::vector<int64_t>& sizes,
const api::ScalarType dtype = api::ScalarType::Float,
const int64_t shared_object_idx = -1);
/*
 * Adds a TensorRef that wraps externally owned data (typically constant
 * weights that are later prepacked into a vTensor).
 */
ValueRef add_tensorref(
const std::vector<int64_t>& sizes,
const api::ScalarType dtype,
const void* const data);
// Adds a staging buffer with the given dtype and number of elements.
ValueRef add_staging(const api::ScalarType dtype, const size_t numel);
template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
add_scalar_list(std::vector<T>&& values);
template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
add_scalar(T value);
ValueRef add_string(std::string&& str);
/*
 * Marks the tensor at idx as a model input/output. If use_staging is true, a
 * staging buffer is also created for transferring data into (or out of) the
 * tensor, and its ValueRef is returned (see add_input_tensor below).
 */
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
/*
* Convenience function to add an input tensor along with its staging buffer
*/
inline IOValueRef add_input_tensor(
const std::vector<int64_t>& sizes,
const api::ScalarType dtype,
const int64_t shared_object_idx = -1) {
ValueRef t = add_tensor(sizes, dtype, shared_object_idx);
ValueRef staging = set_input_tensor(t);
return {t, staging};
}
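/*
 * Illustrative sketch (assumes IOValueRef exposes `value` and `staging`
 * members, matching the brace-initialized return above; `graph` is a
 * hypothetical ComputeGraph instance):
 *
 *   IOValueRef in = graph.add_input_tensor(
 *       {1, 64, 32, 32}, api::ScalarType::Float);
 *   // in.value   -> ValueRef of the vTensor
 *   // in.staging -> ValueRef of its staging buffer
 */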
SharedObject& get_shared_object(const int64_t idx);
//
// Graph Preparation
//
void prepare();
//
// Input/Output
//
void
copy_into_staging(const ValueRef idx, const void* data, const size_t numel);
void copy_from_staging(const ValueRef idx, void* data, const size_t numel);
//
// Graph Prepacking
//
// Encodes the prepack operations (e.g. packing weight data referenced by
// TensorRefs into vTensors) into a command buffer.
void encode_prepack();
// Submits the prepack command buffer; typically called once before inference.
void prepack() const;
//
// Graph Execution
//
// Encodes the graph's execute nodes into a command buffer.
void encode_execute();
// Submits the cached execute command buffer to run one inference.
void execute() const;
};
template <typename T>
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
ComputeGraph::add_scalar_list(std::vector<T>&& values) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(values));
return idx;
}
template <typename T>
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
ComputeGraph::add_scalar(T value) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(value);
return idx;
}
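/*
 * Illustrative end-to-end usage sketch. This is not part of the API; it only
 * shows how the methods declared above are intended to compose. The tensor
 * sizes, the data pointers, the `register_some_op` helper, and the IOValueRef
 * member names (`value` / `staging`) are assumptions made for the example;
 * actual operator registration is handled elsewhere (not declared in this
 * header).
 *
 *   GraphConfig config;
 *   ComputeGraph graph(config);
 *
 *   // Build: declare inputs, constant weights, ops, and outputs.
 *   IOValueRef in = graph.add_input_tensor(
 *       {1, 64, 32, 32}, api::ScalarType::Float);
 *   ValueRef weights = graph.add_tensorref(
 *       {64, 64, 3, 3}, api::ScalarType::Float, weight_data_ptr);
 *   ValueRef out = register_some_op(graph, in.value, weights);  // hypothetical
 *   ValueRef staging_out = graph.set_output_tensor(out);
 *
 *   // Prepare and prepack: size descriptor pools, encode and submit the
 *   // prepack command buffer, then encode the execute command buffer once.
 *   graph.prepare();
 *   graph.encode_prepack();
 *   graph.prepack();
 *   graph.encode_execute();
 *
 *   // Inference: copy inputs in, replay the cached command buffer, copy
 *   // outputs out. Only these steps are repeated per inference.
 *   graph.copy_into_staging(in.staging, input_data_ptr, input_numel);
 *   graph.execute();
 *   graph.copy_from_staging(staging_out, output_data_ptr, output_numel);
 */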
} // namespace vulkan
} // namespace native
} // namespace at
#endif /* USE_VULKAN_API */