[vulkan] convolution old prepacking via cpu-shader (#48330)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48330
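
Adds an experimental path that reproduces the old Vulkan backend's convolution on the new API:
- a CPU-side prepacking routine that repacks OIHW weights into the KO4C4HW
  image layout the legacy shader expects (weight(oc, ic, ky, kx) lands at
  offset 4 * (ic % 4) + (oc % 4) within the 16-float block addressed by
  (oc / 4, ic / 4, ky, kx));
- a conv2d_nogroup_clamp_1x shader variant computing one output texel per
  invocation (wired up but currently commented out in favor of
  conv2d_nogroup_clamp);
- a conv2d_old dispatch plus an Experimentation::kUseConv2dOldApi switch,
  currently enabled, that routes Conv2dOpContext::run through the old-API
  path instead of the pointwise/general split.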

Test Plan: Imported from OSS

Reviewed By: SS-JIA

Differential Revision: D25131500

Pulled By: IvanKobzarev

fbshipit-source-id: b11edb94a78f5d6283c7be1887d72a4ca624a9ab
diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl
new file mode 100644
index 0000000..5cef89c
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/glsl/conv2d_nogroup_clamp_1x.glsl
@@ -0,0 +1,70 @@
+#version 450 core
+#define PRECISION $precision
+layout(std430) buffer;
+layout(std430) uniform;
+layout(set = 0, rgba32f, binding = 0) writeonly PRECISION uniform image3D uOutput;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+layout(set = 0, binding = 2) uniform PRECISION sampler3D uKernel;
+layout(set = 0, binding = 3) readonly buffer bias {
+  vec4 data[];
+}
+uBias;
+layout(set = 0, binding = 4) uniform constBlock {
+  ivec2 padding;
+  ivec2 kernelSize;
+  ivec2 stride;
+  ivec2 dilate;
+  ivec4 outputSize;
+  ivec4 inputSize;
+  float outputMin;
+  float outputMax;
+}
+uConstBlock;
+
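+// Integer ceiling division: UP_DIV(x, y) == ceil(x / y) for y > 0.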
+#define UP_DIV(x, y) (((x) + (y)-1) / (y))
+
+layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in;
+
+void main() {
+  ivec3 pos = ivec3(gl_GlobalInvocationID);
+  if (all(lessThan(pos, uConstBlock.outputSize.xyz))) {
+    int kernelX = uConstBlock.kernelSize.x;
+    int kernelY = uConstBlock.kernelSize.y;
+    ivec3 inputSize = uConstBlock.inputSize.xyz;
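+    // Top-left input coordinate of this output pixel's receptive field.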
+    ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding;
+    int fx, fy, fz;
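+    // Clip the kernel window to taps that land inside the input extents.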
+    ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate)));
+    ivec2 efxy =
+        min(uConstBlock.kernelSize,
+            UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate));
+    vec4 color = uBias.data[pos.z];
+    int kY = pos.z;
+    for (fy = sfxy.y; fy < efxy.y; ++fy) {
+      int sy = fy * uConstBlock.dilate.y + s0.y;
+      // Clamp the x taps to the valid window as well, mirroring fy above.
+      for (fx = sfxy.x; fx < efxy.x; ++fx) {
+        int kZ = fx + fy * kernelX;
+        int sx = fx * uConstBlock.dilate.x + s0.x;
+        for (fz = 0; fz < inputSize.z; ++fz) {
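+          // The four texels at x = 4*fz .. 4*fz+3 are the columns of a 4x4
+          // matrix mapping input channels 4*fz .. 4*fz+3 onto this output slice.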
+          int kX = 4 * fz;
+          vec4 k0 = texelFetch(uKernel, ivec3(kX + 0, kY, kZ), 0);
+          vec4 k1 = texelFetch(uKernel, ivec3(kX + 1, kY, kZ), 0);
+          vec4 k2 = texelFetch(uKernel, ivec3(kX + 2, kY, kZ), 0);
+          vec4 k3 = texelFetch(uKernel, ivec3(kX + 3, kY, kZ), 0);
+
+          mat4 k = mat4(k0, k1, k2, k3);
+
+          color += k * texelFetch(uInput, ivec3(sx, sy, fz), 0);
+        }
+      }
+    }
+    vec4 outputMin = vec4(uConstBlock.outputMin);
+    vec4 outputMax = vec4(uConstBlock.outputMax);
+    imageStore(uOutput, ivec3(pos.x, pos.y, pos.z), clamp(color, outputMin, outputMax));
+  }
+}
diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h
index 3c9b2e8..d8b9905 100644
--- a/aten/src/ATen/native/vulkan/ops/Common.h
+++ b/aten/src/ATen/native/vulkan/ops/Common.h
@@ -35,6 +35,12 @@
   };
 };
 
+struct Experimentation {
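+  // Compile-time switch: when true, convolution weights are prepacked with
+  // the old-API CPU routine and dispatched through conv2d_old.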
+  static constexpr bool kUseConv2dOldApi = true;
+};
+
 } // namespace ops
 } // namespace vulkan
 } // namespace native
diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
index c549468..0a12c7b 100644
--- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -1,4 +1,5 @@
 #include <ATen/native/vulkan/ops/Convolution.h>
+#include <ATen/native/vulkan/api/Utils.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/utils/ParamUtils.h>
 #include <ATen/native/vulkan/ops/Persistent.h>
@@ -66,6 +67,102 @@
   // General
   //
 
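+  // Old-API path: pack the OIHW weights on the CPU into the image layout the
+  // conv2d_nogroup_clamp* shaders sample as uKernel.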
+  if (Experimentation::kUseConv2dOldApi) {
+    const uint32_t OC = src_filter[Layout::Filter::output];
+    const uint32_t OC_4 = at::native::vulkan::api::utils::div_up(OC, 4u);
+    const uint32_t C = src_filter[Layout::Filter::input];
+    const uint32_t C_4 = at::native::vulkan::api::utils::div_up(C, 4u);
+    const uint32_t KH = src_filter[Layout::Filter::height];
+    const uint32_t KW = src_filter[Layout::Filter::width];
+
+    vTensor v_weight{
+      api::context(),
+      &pool,
+      {
+        1,
+        4 * KH * KW,
+        OC_4,
+        4 * C_4
+      },
+      weight.options(),
+    };
+
+    using Future = vTensor::Future<float, vTensor::Access::Write>;
+    Future v_weight_future = v_weight.host<float, vTensor::Access::Write>();
+    Future::Payload v_weight_payload = v_weight_future.wait();
+
+    float* const dst_weight_ptr = v_weight_payload.get();
+    memset(dst_weight_ptr, 0, v_weight.nbytes());
+
+    const float* src = src_weight_ptr;
+    float* const dst = dst_weight_ptr;
+
+    {
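+      // Stage 1: repack OIHW weights into [OC_4][C_4][KH][KW] blocks of 16
+      // floats, so dst_kx[4 * ic_4_i + oc_4_i] holds weight(oc, ic, ky, kx).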
+      uint32_t ridx = 0;
+      const uint32_t oc_4SizeNumel = KW * KH * C_4 * 16;
+      for (uint32_t oc = 0; oc < OC; ++oc) {
+        int oc_4 = oc / 4;
+        int oc_4_i = oc % 4;
+        float* dst_oc = dst + oc_4 * oc_4SizeNumel;
+        for (uint32_t ic = 0; ic < C; ++ic) {
+          int ic_4 = ic / 4;
+          int ic_4_i = ic % 4;
+          float* dst_ic = dst_oc + ic_4 * KW * KH * 16;
+          for (uint32_t ky = 0; ky < KH; ++ky) {
+            float* dst_ky = dst_ic + ky * KW * 16;
+            for (uint32_t kx = 0; kx < KW; ++kx) {
+              float* dst_kx = dst_ky + kx * 16;
+              dst_kx[4 * ic_4_i + oc_4_i] = src[ridx++];
+            }
+          }
+        }
+      }
+
+      // Stage 2 (CPU port of the KO4C4HW_to_image shader): scatter the packed
+      // buffer into the logical image layout [4*C_4][OC_4][KH*KW][4], stored
+      // flat since runtime-sized arrays are a non-standard extension in C++.
+      std::vector<float> image(16 * C_4 * OC_4 * KH * KW, 0.f);
+      const auto image_at =
+          [&](uint32_t x, uint32_t y, uint32_t z, uint32_t v) -> float& {
+        return image[((x * OC_4 + y) * (KH * KW) + z) * 4 + v];
+      };
+      for (uint32_t sx = 0; sx < C_4; ++sx) {
+        for (uint32_t sy = 0; sy < OC_4; ++sy) {
+          for (uint32_t sz = 0; sz < (KH * KW); ++sz) {
+            for (uint32_t vi = 0; vi < 4; ++vi) {
+              int bufferVIdx = 4 * sx * KH * KW + 4 * sy * C_4 * KH * KW + 4 * sz;
+              image_at(4 * sx + 0, sy, sz, vi) = dst[4 * (bufferVIdx + 0) + vi];
+              image_at(4 * sx + 1, sy, sz, vi) = dst[4 * (bufferVIdx + 1) + vi];
+              image_at(4 * sx + 2, sy, sz, vi) = dst[4 * (bufferVIdx + 2) + vi];
+              image_at(4 * sx + 3, sy, sz, vi) = dst[4 * (bufferVIdx + 3) + vi];
+            }
+          }
+        }
+      }
+
+      // Stage 3: inverse of nchw_to_image -- write the image contents into
+      // the vTensor staging buffer in the layout the texture upload expects.
+      const uint32_t W = 4 * C_4;
+      const uint32_t H = OC_4;
+      const uint32_t D = KH * KW;
+      for (uint32_t sx = 0; sx < W; ++sx) {
+        for (uint32_t sy = 0; sy < H; ++sy) {
+          for (uint32_t sz = 0; sz < D; ++sz) {
+            for (uint32_t szvi = 0; szvi < 4; ++szvi) {
+              dst_weight_ptr[W * sy + sx + (4 * sz + szvi) * W * H] =
+                  image_at(sx, sy, sz, szvi);
+            }
+          }
+        }
+      }
+    }
+
+    return v_weight;
+  }
 
   vTensor v_weight{
     api::context(),
@@ -624,6 +721,103 @@
   };
 }
 
+void conv2d_old(
+    api::Context* const context,
+    api::Command::Buffer& command_buffer,
+    vTensor& v_output,
+    const vTensor& v_input,
+    const vTensor& v_weight,
+    const vTensor& v_bias,
+    const IntArrayRef filter,
+    const IntArrayRef stride,
+    const IntArrayRef padding,
+    const IntArrayRef dilation,
+    const float output_min,
+    const float output_max) {
+
+  using namespace api::utils;
+
+  if (v_output.has_image() && v_input.has_image() && v_weight.has_image()) {
+    const int32_t W = v_input.extents().data[0];
+    const int32_t H = v_input.extents().data[1];
+    const int32_t C_4 = v_input.extents().data[2];
+    const int32_t C = 4 * C_4;
+
+    const int32_t OW = v_output.extents().data[0];
+    const int32_t OH = v_output.extents().data[1];
+    const int32_t OC_4 = v_output.extents().data[2];
+    const int32_t OC = 4 * OC_4;
+
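+    // Mirrors the constBlock uniform block declared in the
+    // conv2d_nogroup_clamp* shaders; field order and sizes must match.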
+    const struct {
+      int32_t padding_x, padding_y;
+      int32_t kernel_x, kernel_y;
+      int32_t stride_x, stride_y;
+      int32_t dilate_x, dilate_y;
+      int32_t outputSize[4];
+      int32_t inputSize[4];
+      float outputMin;
+      float outputMax;
+    } block {
+      safe_downcast<int32_t>(padding[Layout::Parameter::width]),
+      safe_downcast<int32_t>(padding[Layout::Parameter::height]),
+      safe_downcast<int32_t>(filter[Layout::Filter::width]),
+      safe_downcast<int32_t>(filter[Layout::Filter::height]),
+      safe_downcast<int32_t>(stride[Layout::Parameter::width]),
+      safe_downcast<int32_t>(stride[Layout::Parameter::height]),
+      safe_downcast<int32_t>(dilation[Layout::Parameter::width]),
+      safe_downcast<int32_t>(dilation[Layout::Parameter::height]),
+      { OW, OH, OC_4, OC },
+      { W, H, C_4, C },
+      output_min,
+      output_max,
+    };
+
+    context->dispatch(
+        command_buffer,
+        {
+          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+          VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+          VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+        },
+        VK_KERNEL(conv2d_nogroup_clamp),
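+        // The conv2d_nogroup_clamp_1x variant added in this diff can be
+        // swapped in here instead: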
+        //VK_KERNEL(conv2d_nogroup_clamp_1x),
+        v_output.extents(),
+        // Write-only access bypasses synchronization but inserts appropriate
+        // barriers if necessary.
+        v_output.image(
+          command_buffer,
+          vTensor::Stage::Compute,
+          vTensor::Access::Write),
+        // Read-only access is implied on const tensors and triggers an async
+        // synchronization if necessary.
+        v_input.image(
+          command_buffer,
+          vTensor::Stage::Compute),
+        // Read-only access is implied on const tensors and triggers an async
+        // synchronization if necessary.
+        v_weight.image(
+          command_buffer,
+          vTensor::Stage::Compute),
+        // Read-only access is implied on const tensors and triggers an async
+        // synchronization if necessary.
+        v_bias.buffer(
+          command_buffer,
+          vTensor::Stage::Compute),
+        // Object lifetime is managed by the resource pool.
+        // It is OK not to keep track of the handle.
+        context->resource().pool.uniform(block).object);
+  }
+  else {
+    TORCH_CHECK(false, "Not implemented! conv2d_old requires image storage for the output, input, and weight tensors.");
+  }
+}
+
 Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
   api::Context* const context = api::context();
 
@@ -664,34 +858,54 @@
           packed_.output_min,
           packed_.output_max);
     }
-    else if (is_pointwise(unpacked_.filter)) {
-      conv2d_pointwise(
-          context,
-          command_buffer,
-          v_output,
-          v_input,
-          packed_.v_weight,
-          packed_.v_bias,
-          packed_.filter,
-          packed_.stride,
-          packed_.padding,
-          packed_.output_min,
-          packed_.output_max);
-    }
     else {
-      conv2d(
-          context,
-          command_buffer,
-          v_output,
-          v_input,
-          packed_.v_weight,
-          packed_.v_bias,
-          packed_.filter,
-          packed_.stride,
-          packed_.padding,
-          packed_.dilation,
-          packed_.output_min,
-          packed_.output_max);
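+      // Experiment: when enabled, take the old-API conv2d_old path here
+      // instead of the pointwise/general split below.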
+      if (Experimentation::kUseConv2dOldApi) {
+        conv2d_old(
+            context,
+            command_buffer,
+            v_output,
+            v_input,
+            packed_.v_weight,
+            packed_.v_bias,
+            packed_.filter,
+            packed_.stride,
+            packed_.padding,
+            packed_.dilation,
+            packed_.output_min,
+            packed_.output_max);
+      } else {
+        if (is_pointwise(unpacked_.filter)) {
+          conv2d_pointwise(
+              context,
+              command_buffer,
+              v_output,
+              v_input,
+              packed_.v_weight,
+              packed_.v_bias,
+              packed_.filter,
+              packed_.stride,
+              packed_.padding,
+              packed_.output_min,
+              packed_.output_max);
+        }
+        else {
+          conv2d(
+              context,
+              command_buffer,
+              v_output,
+              v_input,
+              packed_.v_weight,
+              packed_.v_bias,
+              packed_.filter,
+              packed_.stride,
+              packed_.padding,
+              packed_.dilation,
+              packed_.output_min,
+              packed_.output_max);
+        }
+      }
     }
   }
   command_buffer.end();