[ET-VK] Add type for symbolic integers
Differential Revision: D62144399
Pull Request resolved: https://github.com/pytorch/executorch/pull/5040
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 6c3ec88..a8f57f5 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -43,6 +43,7 @@
VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
+VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)
#undef VALUE_PTR_CLASS_IMPL
@@ -261,6 +262,13 @@
return idx;
}
+ValueRef ComputeGraph::add_symint(const int32_t val) {
+ ValueRef idx(static_cast<int>(values_.size()));
+ check_no_active_value_ptrs();
+ values_.emplace_back(SymInt(context(), val));
+ return idx;
+}
+
ValueRef ComputeGraph::set_input_tensor(
const ValueRef idx,
const bool use_staging) {
@@ -300,6 +308,22 @@
return idx;
}
+vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
+ const ValueRef idx) {
+ if (values_.at(idx).isInt()) {
+ const int32_t val = extract_scalar<int32_t>(idx);
+ create_params_buffer(val);
+ } else if (values_.at(idx).isSymInt()) {
+ SymIntPtr symint = get_symint(idx);
+ return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
+ }
+ VK_THROW("Cannot create a int param buffer for the given value");
+}
+
+void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
+ get_symint(idx)->set(val);
+}
+
SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
if (idx >= shared_objects_.size()) {
shared_objects_.resize(static_cast<size_t>(idx + 1));
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 9b04b08..ac5e0d6 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -63,6 +63,7 @@
DECL_VALUE_PTR_CLASS(DoubleListPtr, std::vector<double>)
DECL_VALUE_PTR_CLASS(BoolListPtr, std::vector<bool>)
DECL_VALUE_PTR_CLASS(ValueListPtr, std::vector<ValueRef>)
+DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt);
#undef DECL_VALUE_PTR_CLASS
@@ -154,6 +155,7 @@
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(DoubleListPtr, double_list, DoubleList)
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(BoolListPtr, bool_list, BoolList)
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(ValueListPtr, value_list, ValueList)
+ GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(SymIntPtr, symint, SymInt);
#undef GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS
@@ -422,16 +424,29 @@
ValueRef add_string(std::string&& str);
+ ValueRef add_symint(const int32_t val);
+
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
template <typename Block>
- const vkapi::BufferBindInfo create_params_buffer(const Block& data) {
+ vkapi::BufferBindInfo create_params_buffer(const Block& data) {
param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));
return vkapi::BufferBindInfo(param_ubos_.back().buffer());
}
/*
+ * Given a ValueRef, do the following depending on the type of the Value:
+ * - If it is a SymInt, return the BufferBindInfo of the ParamsBuffer object
+ * backing the SymInt.
+ * - If it is a regular Int, create a new ParamsBuffer using the integer value
+ * and return the BufferBindInfo of the created ParamsBuffer.
+ */
+ vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);
+
+ void set_symint(const ValueRef idx, const int32_t val);
+
+ /*
* Convenience function to add an input tensor along with its staging buffer
*/
inline IOValueRef add_input_tensor(
@@ -577,6 +592,7 @@
friend class DoubleListPtr;
friend class BoolListPtr;
friend class ValueListPtr;
+ friend class SymIntPtr;
};
template <typename T>
diff --git a/backends/vulkan/runtime/graph/containers/SymInt.cpp b/backends/vulkan/runtime/graph/containers/SymInt.cpp
new file mode 100644
index 0000000..c91db84
--- /dev/null
+++ b/backends/vulkan/runtime/graph/containers/SymInt.cpp
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
+
+namespace vkcompute {
+
+SymInt::SymInt(api::Context* context_p, const int32_t val)
+ : gpu_buffer(context_p, val){};
+
+void SymInt::set(const int32_t val) {
+ gpu_buffer.update(val);
+}
+
+void SymInt::operator=(const int32_t val) {
+ gpu_buffer.update(val);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/containers/SymInt.h b/backends/vulkan/runtime/graph/containers/SymInt.h
new file mode 100644
index 0000000..0c9fbee
--- /dev/null
+++ b/backends/vulkan/runtime/graph/containers/SymInt.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/vulkan/runtime/api/Context.h>
+#include <executorch/backends/vulkan/runtime/api/containers/ParamsBuffer.h>
+
+namespace vkcompute {
+
+/*
+ * Represents a symbolic integer whose value can be variable. It is implemented
+ * as a thin wrapper around a `ParamsBuffer` object that holds the value of the
+ * integer. The `ParamsBuffer` object allows the value of the symbolic integer
+ * to be changed from the CPU and have those changes be visible to all shaders
+ * that use the symbolic integer; it also allows the value of the symbolic
+ * integer to be the result of a compute shader.
+ *
+ * Regular scalar types represented by `TypeTag::INT` cannot be used for
+ * symbolic integers because their value is assumed to be constant; therefore
+ * the `Value` instance holding the value of the scalar does not contain
+ * any reference to the GPU buffers used to pass its value into compute shaders.
+ * Therefore, updating the value of the scalar does not impact the value seen
+ * by compute shaders.
+ */
+struct SymInt final {
+ api::ParamsBuffer gpu_buffer;
+
+ explicit SymInt(api::Context* context_p, const int32_t val);
+
+ void set(const int32_t val);
+
+ void operator=(const int32_t val);
+};
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/containers/Types.cpp b/backends/vulkan/runtime/graph/containers/Types.cpp
index c5ffc65..e7a8951 100644
--- a/backends/vulkan/runtime/graph/containers/Types.cpp
+++ b/backends/vulkan/runtime/graph/containers/Types.cpp
@@ -29,6 +29,7 @@
PRINT_CASE(BOOLLIST)
PRINT_CASE(VALUELIST)
PRINT_CASE(STRING)
+ PRINT_CASE(SYMINT)
}
return out;
}
diff --git a/backends/vulkan/runtime/graph/containers/Types.h b/backends/vulkan/runtime/graph/containers/Types.h
index 79edbd5..5840d16 100644
--- a/backends/vulkan/runtime/graph/containers/Types.h
+++ b/backends/vulkan/runtime/graph/containers/Types.h
@@ -36,6 +36,7 @@
// Special Type
VALUELIST,
STRING,
+ SYMINT,
};
std::ostream& operator<<(std::ostream& out, const TypeTag& tag);
diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h
index 6e03bbd..50a2b5e 100644
--- a/backends/vulkan/runtime/graph/containers/Value.h
+++ b/backends/vulkan/runtime/graph/containers/Value.h
@@ -13,6 +13,7 @@
#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Constant.h>
+#include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Types.h>
namespace vkcompute {
@@ -67,6 +68,8 @@
std::string as_string;
+ SymInt as_symint;
+
Payload() : u() {}
// NOLINTNEXTLINE
~Payload(){};
@@ -123,6 +126,7 @@
TypeTag::VALUELIST, std::vector<ValueRef>, as_value_list, vector);
CASE_MOVE_MOVEABLE_TYPE(
TypeTag::STRING, std::string, as_string, basic_string);
+ CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt);
case TypeTag::NONE:
clearToNone();
@@ -172,6 +176,9 @@
case TypeTag::STRING:
payload.as_string.~basic_string();
break;
+ case TypeTag::SYMINT:
+ payload.as_symint.~SymInt();
+ break;
// Manually list out the types so that if a type here is added later and
// not handled the compiler can catch it.
case TypeTag::NONE:
@@ -288,6 +295,8 @@
TypeTag::STRING,
as_string);
+ SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint);
+
#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
#undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
diff --git a/backends/vulkan/test/glsl/scalar_add_texture.glsl b/backends/vulkan/test/glsl/scalar_add_texture.glsl
new file mode 100644
index 0000000..aa2b22c
--- /dev/null
+++ b/backends/vulkan/test/glsl/scalar_add_texture.glsl
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(0, "rw", "t_in", "float", "texture3d")}
+${layout_declare_ubo(1, "uvec3", "extents")}
+${layout_declare_ubo(2, "int", "scalar")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+ const ivec3 pos = ivec3(gl_GlobalInvocationID);
+ if (any(greaterThanEqual(pos, extents))) {
+ return;
+ }
+
+ vec4 in_tex = imageLoad(t_in, pos);
+ imageStore(t_in, pos, imageLoad(t_in, pos) + float(scalar));
+}
diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp
index 2f9c3d2..a0bfefa 100644
--- a/backends/vulkan/test/vulkan_compute_api_test.cpp
+++ b/backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -1268,6 +1268,64 @@
}
}
+TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) {
+ GraphConfig config;
+ config.set_storage_type_override(utils::kTexture3D);
+ ComputeGraph graph(config);
+
+ std::vector<int64_t> sizes = {8, 64, 124};
+
+ // Build graph
+
+ ValueRef scalar = graph.add_symint(1);
+ IOValueRef a = graph.add_input_tensor(sizes, vkapi::kFloat);
+
+ IOValueRef out = {};
+ out.value = a.value;
+
+ graph.execute_nodes().emplace_back(new ExecuteNode(
+ graph,
+ VK_KERNEL_FROM_STR("scalar_add_texture"),
+ graph.create_global_wg_size(a.value),
+ graph.create_local_wg_size(a.value),
+ // Inputs and Outputs
+ {{out.value, vkapi::MemoryAccessType::WRITE}},
+ // Shader params buffers
+ {graph.texture_limits_ubo(a.value),
+ graph.get_or_create_int_param_buffer(scalar)},
+ // Specialization Constants
+ {},
+ // Resizing Logic
+ nullptr,
+ {}));
+
+ out.staging = graph.set_output_tensor(out.value);
+
+ graph.prepare();
+ graph.encode_execute();
+
+ // Run graph
+
+ for (float i = 5.0f; i < 30.0f; i += 10.0f) {
+ int scalar_val = i - 3.0f;
+ graph.set_symint(scalar, scalar_val);
+
+ float val_a = i + 2.0f;
+ float val_out = val_a + scalar_val;
+
+ fill_vtensor(graph, a, val_a);
+
+ graph.execute();
+
+ EXTRACT_TENSOR(out);
+
+ // Sanity check that the values are correct
+ for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) {
+ CHECK_VALUE(data_out, i, val_out);
+ }
+ }
+}
+
#define CREATE_WEIGHT_TENSOR(name, sizes, dtype, val) \
std::vector<float> data_##name(utils::multiply_integers(sizes)); \
std::fill(data_##name.begin(), data_##name.end(), val); \