[ET-VK]  Add type for symbolic integers

Differential Revision: D62144399

Pull Request resolved: https://github.com/pytorch/executorch/pull/5040
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 6c3ec88..a8f57f5 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -43,6 +43,7 @@
 VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
 VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
 VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
+VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)
 
 #undef VALUE_PTR_CLASS_IMPL
 
@@ -261,6 +262,13 @@
   return idx;
 }
 
+ValueRef ComputeGraph::add_symint(const int32_t val) {
+  ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
+  values_.emplace_back(SymInt(context(), val));
+  return idx;
+}
+
 ValueRef ComputeGraph::set_input_tensor(
     const ValueRef idx,
     const bool use_staging) {
@@ -300,6 +308,22 @@
   return idx;
 }
 
+vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
+    const ValueRef idx) {
+  if (values_.at(idx).isInt()) {
+    const int32_t val = extract_scalar<int32_t>(idx);
+    create_params_buffer(val);
+  } else if (values_.at(idx).isSymInt()) {
+    SymIntPtr symint = get_symint(idx);
+    return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
+  }
+  VK_THROW("Cannot create a int param buffer for the given value");
+}
+
+void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
+  get_symint(idx)->set(val);
+}
+
 SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
   if (idx >= shared_objects_.size()) {
     shared_objects_.resize(static_cast<size_t>(idx + 1));
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 9b04b08..ac5e0d6 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -63,6 +63,7 @@
 DECL_VALUE_PTR_CLASS(DoubleListPtr, std::vector<double>)
 DECL_VALUE_PTR_CLASS(BoolListPtr, std::vector<bool>)
 DECL_VALUE_PTR_CLASS(ValueListPtr, std::vector<ValueRef>)
+DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt);
 
 #undef DECL_VALUE_PTR_CLASS
 
@@ -154,6 +155,7 @@
   GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(DoubleListPtr, double_list, DoubleList)
   GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(BoolListPtr, bool_list, BoolList)
   GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(ValueListPtr, value_list, ValueList)
+  GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(SymIntPtr, symint, SymInt);
 
 #undef GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS
 
@@ -422,16 +424,29 @@
 
   ValueRef add_string(std::string&& str);
 
+  ValueRef add_symint(const int32_t val);
+
   ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
   ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
 
   template <typename Block>
-  const vkapi::BufferBindInfo create_params_buffer(const Block& data) {
+  vkapi::BufferBindInfo create_params_buffer(const Block& data) {
     param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));
     return vkapi::BufferBindInfo(param_ubos_.back().buffer());
   }
 
   /*
+   * Given a ValueRef, do the following depending on the type of the Value:
+   * - If it is a SymInt, return the BufferBindInfo of the ParamsBuffer object
+   *   backing the SymInt.
+   * - If it is a regular Int, create a new ParamsBuffer using the integer value
+   *   and return the BufferBindInfo of the created ParamsBuffer.
+   */
+  vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);
+
+  void set_symint(const ValueRef idx, const int32_t val);
+
+  /*
    * Convenience function to add an input tensor along with its staging buffer
    */
   inline IOValueRef add_input_tensor(
@@ -577,6 +592,7 @@
   friend class DoubleListPtr;
   friend class BoolListPtr;
   friend class ValueListPtr;
+  friend class SymIntPtr;
 };
 
 template <typename T>
diff --git a/backends/vulkan/runtime/graph/containers/SymInt.cpp b/backends/vulkan/runtime/graph/containers/SymInt.cpp
new file mode 100644
index 0000000..c91db84
--- /dev/null
+++ b/backends/vulkan/runtime/graph/containers/SymInt.cpp
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
+
+namespace vkcompute {
+
+SymInt::SymInt(api::Context* context_p, const int32_t val)
+    : gpu_buffer(context_p, val){};
+
+void SymInt::set(const int32_t val) {
+  gpu_buffer.update(val);
+}
+
+void SymInt::operator=(const int32_t val) {
+  gpu_buffer.update(val);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/containers/SymInt.h b/backends/vulkan/runtime/graph/containers/SymInt.h
new file mode 100644
index 0000000..0c9fbee
--- /dev/null
+++ b/backends/vulkan/runtime/graph/containers/SymInt.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/vulkan/runtime/api/Context.h>
+#include <executorch/backends/vulkan/runtime/api/containers/ParamsBuffer.h>
+
+namespace vkcompute {
+
+/*
+ * Represents a symbolic integer whose value can be variable. It is implemented
+ * as a thin wrapper around a `ParamsBuffer` object that holds the value of the
+ * integer. The `ParamsBuffer` object allows the value of the symbolic integer
+ * to be changed from the CPU and have those changes be visible to all shaders
+ * that use the symbolic integer; it also allows the value of the symbolic
+ * integer to be the result of a compute shader.
+ *
+ * Regular scalar types represented by `TypeTag::INT` cannot be used for
+ * symbolic integers because their value is assumed to be constant; therefore
+ * the `Value` instance holding the value of the scalar does not contain
+ * any reference to the GPU buffers used to pass its value into compute shaders.
+ * Therefore, updating the value of the scalar does not impact the value seen
+ * by compute shaders.
+ */
+struct SymInt final {
+  api::ParamsBuffer gpu_buffer;
+
+  explicit SymInt(api::Context* context_p, const int32_t val);
+
+  void set(const int32_t val);
+
+  void operator=(const int32_t val);
+};
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/containers/Types.cpp b/backends/vulkan/runtime/graph/containers/Types.cpp
index c5ffc65..e7a8951 100644
--- a/backends/vulkan/runtime/graph/containers/Types.cpp
+++ b/backends/vulkan/runtime/graph/containers/Types.cpp
@@ -29,6 +29,7 @@
     PRINT_CASE(BOOLLIST)
     PRINT_CASE(VALUELIST)
     PRINT_CASE(STRING)
+    PRINT_CASE(SYMINT)
   }
   return out;
 }
diff --git a/backends/vulkan/runtime/graph/containers/Types.h b/backends/vulkan/runtime/graph/containers/Types.h
index 79edbd5..5840d16 100644
--- a/backends/vulkan/runtime/graph/containers/Types.h
+++ b/backends/vulkan/runtime/graph/containers/Types.h
@@ -36,6 +36,7 @@
   // Special Type
   VALUELIST,
   STRING,
+  SYMINT,
 };
 
 std::ostream& operator<<(std::ostream& out, const TypeTag& tag);
diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h
index 6e03bbd..50a2b5e 100644
--- a/backends/vulkan/runtime/graph/containers/Value.h
+++ b/backends/vulkan/runtime/graph/containers/Value.h
@@ -13,6 +13,7 @@
 #include <executorch/backends/vulkan/runtime/api/api.h>
 
 #include <executorch/backends/vulkan/runtime/graph/containers/Constant.h>
+#include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
 #include <executorch/backends/vulkan/runtime/graph/containers/Types.h>
 
 namespace vkcompute {
@@ -67,6 +68,8 @@
 
     std::string as_string;
 
+    SymInt as_symint;
+
     Payload() : u() {}
     // NOLINTNEXTLINE
     ~Payload(){};
@@ -123,6 +126,7 @@
           TypeTag::VALUELIST, std::vector<ValueRef>, as_value_list, vector);
       CASE_MOVE_MOVEABLE_TYPE(
           TypeTag::STRING, std::string, as_string, basic_string);
+      CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt);
 
       case TypeTag::NONE:
         clearToNone();
@@ -172,6 +176,9 @@
       case TypeTag::STRING:
         payload.as_string.~basic_string();
         break;
+      case TypeTag::SYMINT:
+        payload.as_symint.~SymInt();
+        break;
       // Manually list out the types so that if a type here is added later and
       // not handled the compiler can catch it.
       case TypeTag::NONE:
@@ -288,6 +295,8 @@
       TypeTag::STRING,
       as_string);
 
+  SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint);
+
 #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
 #undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
 
diff --git a/backends/vulkan/test/glsl/scalar_add_texture.glsl b/backends/vulkan/test/glsl/scalar_add_texture.glsl
new file mode 100644
index 0000000..aa2b22c
--- /dev/null
+++ b/backends/vulkan/test/glsl/scalar_add_texture.glsl
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(0, "rw", "t_in", "float", "texture3d")}
+${layout_declare_ubo(1, "uvec3", "extents")}
+${layout_declare_ubo(2, "int", "scalar")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  if (any(greaterThanEqual(pos, extents))) {
+    return;
+  }
+
+  vec4 in_tex = imageLoad(t_in, pos);
+  imageStore(t_in, pos, imageLoad(t_in, pos) + float(scalar));
+}
diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp
index 2f9c3d2..a0bfefa 100644
--- a/backends/vulkan/test/vulkan_compute_api_test.cpp
+++ b/backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -1268,6 +1268,64 @@
   }
 }
 
+TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) {
+  GraphConfig config;
+  config.set_storage_type_override(utils::kTexture3D);
+  ComputeGraph graph(config);
+
+  std::vector<int64_t> sizes = {8, 64, 124};
+
+  // Build graph
+
+  ValueRef scalar = graph.add_symint(1);
+  IOValueRef a = graph.add_input_tensor(sizes, vkapi::kFloat);
+
+  IOValueRef out = {};
+  out.value = a.value;
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
+      VK_KERNEL_FROM_STR("scalar_add_texture"),
+      graph.create_global_wg_size(a.value),
+      graph.create_local_wg_size(a.value),
+      // Inputs and Outputs
+      {{out.value, vkapi::MemoryAccessType::WRITE}},
+      // Shader params buffers
+      {graph.texture_limits_ubo(a.value),
+       graph.get_or_create_int_param_buffer(scalar)},
+      // Specialization Constants
+      {},
+      // Resizing Logic
+      nullptr,
+      {}));
+
+  out.staging = graph.set_output_tensor(out.value);
+
+  graph.prepare();
+  graph.encode_execute();
+
+  // Run graph
+
+  for (float i = 5.0f; i < 30.0f; i += 10.0f) {
+    int scalar_val = i - 3.0f;
+    graph.set_symint(scalar, scalar_val);
+
+    float val_a = i + 2.0f;
+    float val_out = val_a + scalar_val;
+
+    fill_vtensor(graph, a, val_a);
+
+    graph.execute();
+
+    EXTRACT_TENSOR(out);
+
+    // Sanity check that the values are correct
+    for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) {
+      CHECK_VALUE(data_out, i, val_out);
+    }
+  }
+}
+
 #define CREATE_WEIGHT_TENSOR(name, sizes, dtype, val)              \
   std::vector<float> data_##name(utils::multiply_integers(sizes)); \
   std::fill(data_##name.begin(), data_##name.end(), val);          \