[PyTorch][Vulkan] Add template based codegen for shader generation (#88323)

We would like to be able to parameterize kernels such that a parameterized
algorithm can be implemented via templates. We can then profile performance of
a kernel with different parameter values. This enables us to determine what
parameters may work the best for a given kernel or a given device.

In this diff, one such kernel is added: the 1x1 conv (conv2d_pw), which is
parameterized by the size of the tile produced by each invocation.

A few other candidates for parameterization:
- dtype, so that the computation can be done in fp16 or int8/int16.
- Register blocking for input channels.

Differential Revision: [D40280336](https://our.internmc.facebook.com/intern/diff/D40280336/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D40280336/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88323
Approved by: https://github.com/jmdetloff
diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw_2x2.glsl b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw.glslt
similarity index 78%
rename from aten/src/ATen/native/vulkan/glsl/conv2d_pw_2x2.glsl
rename to aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw.glslt
index fe85f3f..1918484 100644
--- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw_2x2.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw.glslt
@@ -1,10 +1,7 @@
-#version 450 core
-#define PRECISION $precision
-#define FORMAT $format
-
 /*
- * TILE_SIZE = (2, 2, 1)
- * WEIGHT_STORAGE = TEXTURE_3D
+ * TILE_SIZE = ($TILE_SIZE_X, $TILE_SIZE_Y, 1)
+ * WEIGHT_STORAGE = TEXTURE_2D
+ * WEIGHT_STORAGE_LAYOUT = OC4,IC4,4ic,4oc
  */
 
 layout(std430) buffer;
@@ -54,17 +51,19 @@
 void main() {
   const ivec3 gpos = ivec3(gl_GlobalInvocationID);
 
-  // Determine the output positions that will be written to.
+  // Output position for TILE_SIZE_X, TILE_SIZE_Y = 2, 2
   // +--------+--------+
   // | pos[0] | pos[1] |
   // +--------+--------+
   // | pos[2] | pos[3] |
   // +--------+--------+
-  ivec3 pos[4];
-  pos[0] = ivec3(gpos.x * 2, gpos.y * 2, gpos.z);
-  pos[1] = ivec3(gpos.x * 2 + 1, gpos.y * 2, gpos.z);
-  pos[2] = ivec3(gpos.x * 2, gpos.y * 2 + 1, gpos.z);
-  pos[3] = ivec3(gpos.x * 2 + 1, gpos.y * 2 + 1, gpos.z);
+  ivec3 pos[$TILE_SIZE_X * $TILE_SIZE_Y];
+  for (int y = 0, i = 0; y < $TILE_SIZE_Y; ++y) {
+    for (int x = 0; x < $TILE_SIZE_X; ++x) {
+      pos[i] = ivec3(gpos.x * $TILE_SIZE_X + x, gpos.y * $TILE_SIZE_Y + y, gpos.z);
+      i++;
+    }
+  }
 
   // If the top left position is out of bounds, then this invocation will have
   // no work to do.
@@ -75,14 +74,14 @@
   // Compute the index of the input texture that needs to be loaded for each
   // output position. Note that negative indices can be produced indicating that
   // the top-left element is in a region added by padding.
-  ivec2 ipos[4];
-  for (int i = 0; i < 4; ++i) {
+  ivec2 ipos[$TILE_SIZE_X * $TILE_SIZE_Y];
+  for (int i = 0; i < $TILE_SIZE_X * $TILE_SIZE_Y; ++i) {
     ipos[i] = pos[i].xy * uBlock.stride - uBlock.padding;
   }
 
-  vec4 sum[4];
+  vec4 sum[$TILE_SIZE_X * $TILE_SIZE_Y];
   sum[0] = texelFetch(uBias, ivec2(gpos.z, 0), 0);
-  for (int i = 1; i < 4; ++i) {
+  for (int i = 1; i < $TILE_SIZE_X * $TILE_SIZE_Y; ++i) {
     sum[i] = sum[0];
   }
 
@@ -92,13 +91,18 @@
     // During prepacking, the weight tensor has been permuted so that the
     // channel (IC) dim is along the x axis, and the batch (OC) dim is along
     // the z axis.
+    vec4 in_tex[$TILE_SIZE_X * $TILE_SIZE_Y];
     const vec4 ktex_0 = texelFetch(uKernel, ivec2(z + 0, gpos.z), 0);
     const vec4 ktex_1 = texelFetch(uKernel, ivec2(z + 1, gpos.z), 0);
     const vec4 ktex_2 = texelFetch(uKernel, ivec2(z + 2, gpos.z), 0);
     const vec4 ktex_3 = texelFetch(uKernel, ivec2(z + 3, gpos.z), 0);
 
-    for (int i = 0; i < 4; ++i) {
-      const vec4 in_tex = texelFetch(uInput, ivec3(ipos[i], z4), 0);
+    for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) {
+      in_tex[i] = texelFetch(uInput, ivec3(ipos[i], z4), 0);
+    }
+
+    for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) {
+      // For a 2x2 tile size, the algorithm works as follows.
       // To explain the calculations below, the contents one in_tex and the
       // group of 4 texels loaded from uKernel are shown:
       //
@@ -131,15 +135,14 @@
       //
       //  which is what is expressed in the following calculations. This is done
       //  for each output position.
-
-      sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
-      sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
-      sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
-      sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
+      sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]);
+      sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]);
+      sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]);
+      sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]);
     }
   }
 
-  for (int i = 0; i < 4; ++i) {
+  for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) {
     if (all(lessThan(pos[i], uBlock.out_extents.xyz))) {
       imageStore(
           uOutput,
diff --git a/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw_params.yaml b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw_params.yaml
new file mode 100644
index 0000000..fef8f20
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw_params.yaml
@@ -0,0 +1,7 @@
+conv2d_pw:
+  parameter_names_with_default_values:
+      TILE_SIZE_X: 2
+      TILE_SIZE_Y: 2
+  parameter_values:
+    - TILE_SIZE_X: 1
+      TILE_SIZE_Y: 1
diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl
index e61ab02..6d16e8f 100644
--- a/tools/BUCK.bzl
+++ b/tools/BUCK.bzl
@@ -213,6 +213,18 @@
         base_module = "",
         deps = [
             torchgen_deps,
+            ":gen_aten_vulkan_glsl_lib",
+        ],
+    )
+
+    python_library(
+        name = "gen_aten_vulkan_glsl_lib",
+        srcs = [
+            "gen_vulkan_glsl.py",
+        ],
+        base_module = "tools",
+        deps = [
+            torchgen_deps,
         ],
     )
 
@@ -223,6 +235,20 @@
             "PUBLIC",
         ],
         deps = [
+            ":gen_aten_vulkan_glsl_lib",
+            ":gen_aten_vulkan_spv_lib",
+        ],
+    )
+
+    python_test(
+        name = "vulkan_codegen_test",
+        srcs = [
+            "test/test_vulkan_codegen.py",
+        ],
+        contacts = contacts,
+        visibility = ["PUBLIC"],
+        deps = [
+            ":gen_aten_vulkan_glsl_lib",
             ":gen_aten_vulkan_spv_lib",
         ],
     )
diff --git a/tools/gen_vulkan_glsl.py b/tools/gen_vulkan_glsl.py
new file mode 100644
index 0000000..bf6f16d
--- /dev/null
+++ b/tools/gen_vulkan_glsl.py
@@ -0,0 +1,111 @@
+import copy
+import os
+
+import yaml
+
+from torchgen.code_template import CodeTemplate
+from yaml.constructor import ConstructorError
+from yaml.nodes import MappingNode
+
+try:
+    from yaml import CLoader as Loader
+except ImportError:
+    from yaml import Loader  # type: ignore[misc]
+
+# https://gist.github.com/pypt/94d747fe5180851196eb
+class UniqueKeyLoader(Loader):
+    """YAML loader whose mappings reject unhashable and repeated keys."""
+
+    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
+        """Build a dict from ``node``, one key/value pair at a time.
+
+        Raises ConstructorError if ``node`` is not a mapping node, if a key is
+        not hashable, or if a key occurs more than once in the same mapping.
+        """
+        if not isinstance(node, MappingNode):
+            problem = f"expected a mapping node, but found {node.id}"
+            raise ConstructorError(None, None, problem, node.start_mark)
+        context = "while constructing a mapping"
+        map_mark = node.start_mark
+        result = {}
+        for key_node, value_node in node.value:
+            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
+            try:
+                hash(key)
+            except TypeError:
+                raise ConstructorError(
+                    context, map_mark, "found unacceptable key ", key_node.start_mark
+                )
+            # A key already seen in this mapping is an error rather than an override.
+            if key in result:
+                raise ConstructorError(
+                    context, map_mark, "found duplicate key", key_node.start_mark
+                )
+            # The value is only constructed once its key has passed both checks.
+            item = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
+            result[key] = item
+        return result
+
+
+class GLSLGenerator:
+    """Expands ``<op>.glslt`` templates into one ``.glsl`` file per parameter set."""
+
+    # Prepended verbatim to each generated file; $precision/$format stay unexpanded.
+    standard_header = """
+#version 450 core
+#define PRECISION $precision
+#define FORMAT $format
+
+"""
+
+    def __init__(self):  # type: ignore[no-untyped-def]
+        self.ops_template_params = {}  # op name -> [defaults, variant, ...]
+
+    def add_params_yaml(self, parameters_yaml_file):  # type: ignore[no-untyped-def]
+        """Load one params YAML file and register every op it describes."""
+        with open(parameters_yaml_file, "r") as f:
+            # An empty document loads as None; treat it as "no ops".
+            contents = yaml.load(f, Loader=UniqueKeyLoader) or {}
+        self.validate_and_construct_op_params(contents)  # type: ignore[no-untyped-call]
+
+    def validate_and_construct_op_params(self, all_template_params):  # type: ignore[no-untyped-def]
+        """Store, per op, the defaults followed by each variant completed from them.
+
+        Variants are kept in the defaults' key order so that output file names do
+        not depend on YAML key order. Raises KeyError for an op that is already
+        registered or for a variant key that has no default.
+        """
+        for op, op_params in all_template_params.items():
+            if op in self.ops_template_params:
+                raise KeyError(f"{op} params file has already been parsed")
+            defaults = op_params["parameter_names_with_default_values"]
+            param_sets = [defaults]
+            for param_vals in op_params["parameter_values"]:
+                invalid_keys = set(param_vals) - set(defaults)
+                if invalid_keys:
+                    raise KeyError(f"Invalid keys {invalid_keys} are found")
+                merged = {k: param_vals.get(k, v) for k, v in defaults.items()}
+                param_sets.append(copy.deepcopy(merged))
+            self.ops_template_params[op] = param_sets
+
+    def generate(self, glsl_template_in, out_dir):  # type: ignore[no-untyped-def]
+        """Render the template once per parameter set into ``<op>_<v1>x<v2>.glsl``."""
+        glsl_template_name = os.path.basename(glsl_template_in)
+        # Split on the last dot only, so an extra dot cannot break the unpacking.
+        op_name, _, extension_name = glsl_template_name.rpartition(".")
+        if extension_name != "glslt":
+            raise TypeError(f"invalid file type for glsl template {extension_name}")
+        if op_name not in self.ops_template_params:
+            raise KeyError(f"{op_name} params have not been populated")
+        code_template = CodeTemplate.from_file(glsl_template_in)
+        for template_params in self.ops_template_params[op_name]:
+            param_vals_string = "x".join(str(v) for v in template_params.values())
+            output_file = os.path.join(out_dir, f"{op_name}_{param_vals_string}.glsl")
+            content = self.standard_header + code_template.substitute(template_params)
+            with open(output_file, "w") as f:
+                f.write(content)
+
+
+# TODO: remove this placeholder; it does nothing when the file is run directly.
+if __name__ == "__main__":
+    pass
diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py
index 1d37a95..37aa97a 100644
--- a/tools/gen_vulkan_spv.py
+++ b/tools/gen_vulkan_spv.py
@@ -11,6 +11,8 @@
 from dataclasses import dataclass
 from typing import List
 
+from tools.gen_vulkan_glsl import GLSLGenerator
+
 H_NAME = "spv.h"
 CPP_NAME = "spv.cpp"
 DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
@@ -78,6 +80,26 @@
 
     return shader_info
 
+def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
+    """Expand every ``.glslt`` template under ``src_dir_path`` into ``tmp_dir_path``.
+
+    Parameter YAML files are collected recursively from ``<src_dir_path>/templates``
+    and registered first, because a template can only be expanded once the
+    parameters for its op are known. Both file lists are sorted once, so the
+    processing order does not depend on directory listing order.
+    """
+    generator = GLSLGenerator()
+    # Register parameters before expanding anything.
+    yaml_pattern = os.path.join(src_dir_path, "templates", '**', '*.yaml')
+    for params_yaml in sorted(glob.glob(yaml_pattern, recursive=True)):
+        generator.add_params_yaml(params_yaml)  # type: ignore[no-untyped-call]
+
+    # The generator opens its output files directly, so the directory must exist.
+    os.makedirs(tmp_dir_path, exist_ok=True)
+    glslt_pattern = os.path.join(src_dir_path, '**', '*.glslt')
+    for glslt in sorted(glob.glob(glslt_pattern, recursive=True)):
+        generator.generate(glslt, tmp_dir_path)  # type: ignore[no-untyped-call]
+
 def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
     print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
         hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))
@@ -88,6 +110,14 @@
         if len(f) > 1:
             templateSrcPaths.append(f)
             templateSrcPaths.sort()
+
+    # Now add glsl files that are generated from templates
+    genGLSLFromGLSLT(srcDirPath, tmpDirPath)
+    vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
+    for f in vexs:
+        if len(f) > 1:
+            templateSrcPaths.append(f)
+            templateSrcPaths.sort()
     print("templateSrcPaths:{}".format(templateSrcPaths))
 
     spvPaths = {}
diff --git a/tools/test/test_vulkan_codegen.py b/tools/test/test_vulkan_codegen.py
new file mode 100644
index 0000000..26ccc66
--- /dev/null
+++ b/tools/test/test_vulkan_codegen.py
@@ -0,0 +1,100 @@
+import os
+import tempfile
+import unittest
+
+from tools.gen_vulkan_glsl import GLSLGenerator
+from yaml.constructor import ConstructorError
+
+
+class TestGLSLCodegen(unittest.TestCase):
+    def test_assert_on_duplicate_key_yaml(self) -> None:
+        # The same top-level op appears twice, which the loader must reject.
+        yaml_with_duplicate_keys = """
+conv2d_pw:
+  parameter_names_with_default_values:
+      TILE_SIZE_X: 1
+      TILE_SIZE_Y: 1
+  parameter_values:
+    - TILE_SIZE_X: 2
+      TILE_SIZE_Y: 2
+    - TILE_SIZE_X: 2
+      TILE_SIZE_Y: 4
+    - TILE_SIZE_X: 4
+      TILE_SIZE_Y: 2
+    - TILE_SIZE_X: 4
+      TILE_SIZE_Y: 4
+conv2d_pw:
+  parameter_names_with_default_values:
+    - TILE_SIZE_X: 1
+    - TILE_SIZE_Y: 1
+  parameter_values:
+    - TILE_SIZE_X: 2
+      TILE_SIZE_Y: 2
+    - TILE_SIZE_X: 2
+      TILE_SIZE_Y: 4
+    - TILE_SIZE_X: 4
+      TILE_SIZE_Y: 2
+    - TILE_SIZE_X: 4
+      TILE_SIZE_Y: 4
+"""
+        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
+        with tempfile.NamedTemporaryFile(mode="w") as fp:
+            fp.write(yaml_with_duplicate_keys)
+            fp.flush()
+            with self.assertRaisesRegex(
+                ConstructorError, r"while constructing a mapping"
+            ):
+                generator.add_params_yaml(fp.name)  # type: ignore[no-untyped-call]
+
+    def test_assert_keys_mismatch(self) -> None:
+        # TILE_SIZE_Z has no default, so it must be reported as an invalid key.
+        yaml_with_key_mismatch = """
+conv2d_pw:
+  parameter_names_with_default_values:
+      TILE_SIZE_X: 1
+      TILE_SIZE_Y: 1
+  parameter_values:
+    - TILE_SIZE_X: 2
+      TILE_SIZE_Z: 2
+"""
+        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
+        with tempfile.NamedTemporaryFile(mode="w") as fp:
+            fp.write(yaml_with_key_mismatch)
+            fp.flush()
+            with self.assertRaisesRegex(KeyError, r"Invalid keys {'TILE_SIZE_Z'}"):
+                generator.add_params_yaml(fp.name)  # type: ignore[no-untyped-call]
+
+    def test_missing_key_default_val(self) -> None:
+        # The only variant omits TILE_SIZE_Y, which must fall back to its default.
+        yaml_with_partial_variant = """
+conv2d_pw:
+  parameter_names_with_default_values:
+      TILE_SIZE_X: 1
+      TILE_SIZE_Y: 1
+  parameter_values:
+    - TILE_SIZE_X: 2
+"""
+        file_content = """
+x = $TILE_SIZE_X + $TILE_SIZE_Y
+"""
+        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
+        with tempfile.NamedTemporaryFile(mode="w") as fp:
+            fp.write(yaml_with_partial_variant)
+            fp.flush()
+            generator.add_params_yaml(fp.name)  # type: ignore[no-untyped-call]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            template_file_name = os.path.join(tmp_dir, "conv2d_pw.glslt")
+            with open(template_file_name, "w") as template_file:
+                template_file.write(file_content)
+            # The template is closed, and therefore flushed, before it is read back.
+            generator.generate(template_file_name, tmp_dir)  # type: ignore[no-untyped-call]
+            # First file comes from the defaults; the second fills TILE_SIZE_Y from them.
+            expected_snippets = {
+                "conv2d_pw_1x1.glsl": "1 + 1",
+                "conv2d_pw_2x1.glsl": "2 + 1",
+            }
+            for file_name, snippet in expected_snippets.items():
+                generated_path = os.path.join(tmp_dir, file_name)
+                self.assertTrue(os.path.exists(generated_path))
+                with open(generated_path, "r") as f:
+                    self.assertIn(snippet, f.read())