Use dequantized weight and bias in conv2d quantized ops (#115615)

Summary:
Dequantize weight and bias for conv2d ops to improve performance. The weight and bias tensors are usually small, so dequantizing them does not significantly increase the memory footprint.

With optimization cunet-enc ops:
vulkan.quantized_conv2d  {96, 72, 2}                      3753204
vulkan.quantized_conv2d  {96, 72, 2}                      6977048
vulkan.quantized_conv2d_dw{96, 72, 2}                      2499640
vulkan.quantized_conv2d_pw_2x2{96, 72, 2}                       842088
vulkan.quantized_conv2d  {48, 36, 4}                      2388152
vulkan.quantized_conv2d  {48, 36, 4}                      4775940
vulkan.quantized_conv2d_dw{48, 36, 4}                       709800
vulkan.quantized_conv2d_pw_2x2{48, 36, 4}                       483236
vulkan.quantized_conv2d  {24, 18, 8}                      2562144
vulkan.quantized_conv2d  {24, 18, 8}                      5447624
vulkan.quantized_conv2d_dw{24, 18, 8}                       392756
vulkan.quantized_conv2d_pw_2x2{24, 18, 8}                       509080

Without optimization:
vulkan.quantized_conv2d  {96, 72, 2}                      4291768
vulkan.quantized_conv2d  {96, 72, 2}                      7871344
vulkan.quantized_conv2d_dw{96, 72, 2}                      2658500
vulkan.quantized_conv2d_pw_2x2{96, 72, 2}                       891020
vulkan.quantized_conv2d  {48, 36, 4}                      2966860
vulkan.quantized_conv2d  {48, 36, 4}                      5661812
vulkan.quantized_conv2d_dw{48, 36, 4}                       816556
vulkan.quantized_conv2d_pw_2x2{48, 36, 4}                       528632
vulkan.quantized_conv2d  {24, 18, 8}                      3139604
vulkan.quantized_conv2d  {24, 18, 8}                      6202820
vulkan.quantized_conv2d_dw{24, 18, 8}                       452660
vulkan.quantized_conv2d_pw_2x2{24, 18, 8}                       557388

Test Plan:
Ensure all vulkan quantize tests pass:
buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
[==========] Running 78 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 78 tests from VulkanAPITest

...
[==========] 78 tests from 1 test suite ran. (1519 ms total)
[  PASSED  ] 78 tests.

buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output

Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
[==========] Running 395 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 395 tests from VulkanAPITest

...
[----------] 395 tests from VulkanAPITest (6515 ms total)

[----------] Global test environment tear-down
[==========] 395 tests from 1 test suite ran. (6515 ms total)
[  PASSED  ] 394 tests.
[  SKIPPED ] 1 test, listed below:
[  SKIPPED ] VulkanAPITest.querypool_flushed_shader_log

  YOU HAVE 5 DISABLED TESTS

Reviewed By: yipjustin

Differential Revision: D50997532

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115615
Approved by: https://github.com/manuelcandales, https://github.com/yipjustin
diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl
index f2915cf..df905a8 100644
--- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl
@@ -21,8 +21,8 @@
  * Input Textures
  */
 layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
-layout(set = 0, binding = 2) uniform PRECISION isampler2D uKernel;
-layout(set = 0, binding = 3) uniform PRECISION isampler2D uBias;
+layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
+layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
 
 /*
  * Params Buffer
@@ -102,10 +102,7 @@
   kstart.x *= 4;
   kstart.y += pos.z * uBlock.kernel_size.y;
 
-  vec4 sum = dequantize(
-      texelFetch(uBias, ivec2(pos.z, 0), 0),
-      uBlock.scales.w,
-      uBlock.zero_points.w);
+  vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);
 
   // Perform the convolution by iterating over the overlay region
   const int dil_y = uBlock.dilate.y;
@@ -152,28 +149,16 @@
         //
         //  which is what is expressed in the following calculations.
 
-        const vec4 ktex_0 = dequantize(
-            texelFetch(uKernel, ivec2(kx + 0, ky), 0),
-            uBlock.scales.z,
-            uBlock.zero_points.z);
+        const vec4 ktex_0 = texelFetch(uKernel, ivec2(kx + 0, ky), 0);
         sum = fma(in_tex.xxxx, ktex_0, sum);
 
-        const vec4 ktex_1 = dequantize(
-            texelFetch(uKernel, ivec2(kx + 1, ky), 0),
-            uBlock.scales.z,
-            uBlock.zero_points.z);
+        const vec4 ktex_1 = texelFetch(uKernel, ivec2(kx + 1, ky), 0);
         sum = fma(in_tex.yyyy, ktex_1, sum);
 
-        const vec4 ktex_2 = dequantize(
-            texelFetch(uKernel, ivec2(kx + 2, ky), 0),
-            uBlock.scales.z,
-            uBlock.zero_points.z);
+        const vec4 ktex_2 = texelFetch(uKernel, ivec2(kx + 2, ky), 0);
         sum = fma(in_tex.zzzz, ktex_2, sum);
 
-        const vec4 ktex_3 = dequantize(
-            texelFetch(uKernel, ivec2(kx + 3, ky), 0),
-            uBlock.scales.z,
-            uBlock.zero_points.z);
+        const vec4 ktex_3 = texelFetch(uKernel, ivec2(kx + 3, ky), 0);
         sum = fma(in_tex.wwww, ktex_3, sum);
       }
     }
diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl
index 817ba4f..eb4d910 100644
--- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl
@@ -22,8 +22,8 @@
  * Input Textures
  */
 layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
-layout(set = 0, binding = 2) uniform PRECISION isampler2D uKernel;
-layout(set = 0, binding = 3) uniform PRECISION isampler2D uBias;
+layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
+layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
 
 /*
  * Params Buffer
@@ -89,10 +89,7 @@
   // reading the input
   const ivec2 kstart = (start - ipos) / uBlock.dilate;
 
-  vec4 sum = dequantize(
-      texelFetch(uBias, ivec2(pos.z, 0), 0),
-      uBlock.scales.w,
-      uBlock.zero_points.w);
+  vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0);
 
   const int dil_y = uBlock.dilate.y;
   const int dil_x = uBlock.dilate.x;
@@ -103,10 +100,7 @@
       // other vertically.
       const int k_ind = kx + ky * uBlock.kernel_size.x;
 
-      const vec4 k_tex = dequantize(
-          texelFetch(uKernel, ivec2(k_ind, pos.z), 0),
-          uBlock.scales.z,
-          uBlock.zero_points.z);
+      const vec4 k_tex = texelFetch(uKernel, ivec2(k_ind, pos.z), 0);
       const vec4 in_tex = dequantize(
           texelFetch(uInput, ivec3(x, y, pos.z), 0),
           uBlock.scales.y,
diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl
index ac325e2..5b49759 100644
--- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl
+++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl
@@ -17,8 +17,8 @@
  * Input Textures
  */
 layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
-layout(set = 0, binding = 2) uniform PRECISION isampler2D uKernel;
-layout(set = 0, binding = 3) uniform PRECISION isampler2D uBias;
+layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel;
+layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias;
 
 /*
  * Params Buffer
@@ -99,10 +99,7 @@
   }
 
   vec4 sum[4];
-  sum[0] = dequantize(
-      texelFetch(uBias, ivec2(gpos.z, 0), 0),
-      uBlock.scales.w,
-      uBlock.zero_points.w);
+  sum[0] = texelFetch(uBias, ivec2(gpos.z, 0), 0);
   for (int i = 1; i < 4; ++i) {
     sum[i] = sum[0];
   }
@@ -113,22 +110,10 @@
     // During prepacking, the weight tensor has been permuted so that the
     // channel (IC) dim is along the x axis, and the batch (OC) dim is along
     // the z axis.
-    const vec4 ktex_0 = dequantize(
-        texelFetch(uKernel, ivec2(z + 0, gpos.z), 0),
-        uBlock.scales.z,
-        uBlock.zero_points.z);
-    const vec4 ktex_1 = dequantize(
-        texelFetch(uKernel, ivec2(z + 1, gpos.z), 0),
-        uBlock.scales.z,
-        uBlock.zero_points.z);
-    const vec4 ktex_2 = dequantize(
-        texelFetch(uKernel, ivec2(z + 2, gpos.z), 0),
-        uBlock.scales.z,
-        uBlock.zero_points.z);
-    const vec4 ktex_3 = dequantize(
-        texelFetch(uKernel, ivec2(z + 3, gpos.z), 0),
-        uBlock.scales.z,
-        uBlock.zero_points.z);
+    const vec4 ktex_0 = texelFetch(uKernel, ivec2(z + 0, gpos.z), 0);
+    const vec4 ktex_1 = texelFetch(uKernel, ivec2(z + 1, gpos.z), 0);
+    const vec4 ktex_2 = texelFetch(uKernel, ivec2(z + 2, gpos.z), 0);
+    const vec4 ktex_3 = texelFetch(uKernel, ivec2(z + 3, gpos.z), 0);
 
     vec4 in_tex[4];
     for (int i = 0; i < 4; ++i) {
diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
index d278a32..1f4184d 100644
--- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -13,6 +13,7 @@
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #else
+#include <ATen/ops/dequantize.h>
 #include <ATen/ops/pad.h>
 #include <ATen/ops/permute.h>
 #include <ATen/ops/quantize_per_tensor.h>
@@ -506,14 +507,16 @@
 using namespace api::utils;
 
 vTensor pack_weights(
-    const Tensor& weight_arg,
+    const Tensor& weight_inp,
     const bool transposed,
     const bool quantized,
     const Conv2dMethod conv_method) {
-  if (weight_arg.is_vulkan()) {
-    return convert(weight_arg);
+  if (weight_inp.is_vulkan()) {
+    return convert(weight_inp);
   }
 
+  const Tensor weight_arg = quantized ? at::dequantize(weight_inp) : weight_inp;
+
   const Tensor weight = transposed
       ? at::permute(weight_arg, {1, 0, 2, 3}).contiguous()
       : weight_arg.contiguous();
@@ -532,12 +535,6 @@
       api::StorageType::TEXTURE_2D,
   };
 
-  if (quantized) {
-    v_weight.set_is_quantized();
-    v_weight.set_scale(weight_arg.q_scale());
-    v_weight.set_zero_point(weight_arg.q_zero_point());
-  }
-
   pack_cpu_to_vulkan(weight_rearranged, v_weight);
 
   return v_weight;
@@ -549,8 +546,11 @@
     const bool transposed,
     const bool quantized) {
   at::Tensor bias_arg = conv2d::rearrange_bias(bias, weight, transposed);
-  at::Tensor bias_rearranged = (quantized && bias_arg.scalar_type() == kFloat)
-      ? at::quantize_per_tensor(bias_arg, weight.q_scale(), 0, c10::kQInt32)
+  at::Tensor bias_rearranged =
+      (quantized &&
+       (bias_arg.scalar_type() == kQUInt8 || bias_arg.scalar_type() == kQInt8 ||
+        bias_arg.scalar_type() == kQInt32))
+      ? at::dequantize(bias_arg)
       : bias_arg;
 
   vTensor v_bias{
@@ -560,12 +560,6 @@
       api::StorageType::TEXTURE_2D,
   };
 
-  if (quantized) {
-    v_bias.set_is_quantized();
-    v_bias.set_scale(bias_rearranged.q_scale());
-    v_bias.set_zero_point(bias_rearranged.q_zero_point());
-  }
-
   pack_cpu_to_vulkan(bias_rearranged, v_bias);
 
   return v_bias;