[Vulkan] Optimize cat operator for channel dimension (#67207)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67207

Improved performance of the `cat` operator along the channel dimension:
* Faster when every input tensor's channel size is a multiple of 4: in that case whole texture slices can be copied with `vkCmdCopyImage` instead of going through a compute shader.
* Added new test cases to cover this scenario.
* Limitation: we can't mix the shader path and `vkCmdCopyImage` within a single call. The output texture is created differently for the two paths, so they can't be combined unless the output texture is recreated every time. We therefore use `vkCmdCopyImage` only if every input tensor's channel size is a multiple of 4 (a minimal sketch of this dispatch follows below).
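
For reference, a minimal sketch of the dispatch decision, simplified from the `Concat.cpp` change in this diff (error handling omitted):
```
// Sketch only -- simplified from the aten/src/ATen/native/vulkan/ops/Concat.cpp change below.
// Use the vkCmdCopyImage-based path only if every input's channel count is a multiple of 4;
// otherwise fall back to the existing shader-based cat_feature path.
bool is_mult4ch = true;
for (const auto& t : tensors) {
  if (t.sizes()[1] % 4 != 0) {
    is_mult4ch = false;
  }
}

if (dim == 1) {  // channel (feature) dimension
  if (is_mult4ch) {
    return cat_feature_mult4ch(tensors, v_output);  // texture-to-texture copies
  }
  return cat_feature(tensors, v_output);            // shader-based path
}
```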

{F673815905}

Test Plan:
**Test Conditions**
* 3 input tensors with size `{3, 40, 221, 193}`
* Number of iterations: `1,000`
* Compare the `Time` column (the `CPU` column only shows CPU execution time)
* Resources are flushed every iteration since the input tensors are large
* Running `vulkan_perf_test` requires a separate diff (D31906379); a hypothetical sketch of such a benchmark case follows this list
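
For illustration only, here is a hypothetical sketch of what such a perf case could look like with Google Benchmark. The actual benchmark is added in D31906379; the function name, argument layout, and the omission of the per-iteration resource flush are assumptions based on the output below, not the real test code.
```
#include <benchmark/benchmark.h>
#include <ATen/ATen.h>
#include <vector>

// Hypothetical sketch: times at::cat of three Vulkan tensors along the channel dimension.
// The actual vulkan_perf_test (D31906379) also flushes Vulkan resources every iteration,
// which is omitted here.
static void cat_op_channel_perf(benchmark::State& state) {
  const std::vector<int64_t> size = {
      state.range(0), state.range(1), state.range(2), state.range(3)};  // {N, C, H, W}
  const auto in1 = at::rand(size, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
  const auto in2 = at::rand(size, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
  const auto in3 = at::rand(size, at::device(at::kCPU).dtype(at::kFloat)).vulkan();
  for (auto _ : state) {
    const auto out = at::cat({in1, in2, in3}, /*dim=*/1);
    benchmark::DoNotOptimize(out);
  }
}
BENCHMARK(cat_op_channel_perf)
    ->ArgNames({"N", "C", "H", "W"})
    ->Args({3, 40, 221, 193})
    ->Iterations(1000)
    ->Threads(1);
BENCHMARK_MAIN();
```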

**Test build on Android**
```
cd ~/fbsource
buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_perf_test_binAndroid\#android-arm64 --show-output
adb push buck-out/gen/xplat/caffe2/pt_vulkan_perf_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_perf_test
adb shell "/data/local/tmp/vulkan_perf_test"
```
**Test build on Mac**
```
cd ~/fbsource
buck build //xplat/caffe2:pt_vulkan_perf_test_binAppleMac
./buck-out/gen/xplat/caffe2/pt_vulkan_perf_test_binAppleMac\#macosx-x86_64
```

**Test results on Google Pixel 5**
a) Without using `vkCmdCopyImage` for channel sizes that are multiples of 4
```
Run on (8 X 1804.8 MHz CPU s)
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
-------------------------------------------------------------------------------------------------------------
Benchmark (Without optimization for 4x channels)                            Time             CPU   Iterations
-------------------------------------------------------------------------------------------------------------
cat_op_channel_perf/N:3/C:40/H:221/W:193/iterations:1000/threads:1       60.4 ms         14.1 ms         1000
cat_op_channel_perf/N:3/C:20/H:221/W:193/iterations:1000/threads:1       24.1 ms        0.947 ms         1000
cat_op_channel_perf/N:3/C:39/H:221/W:193/iterations:1000/threads:1       59.6 ms         14.0 ms         1000
cat_op_channel_perf/N:3/C:4/H:221/W:193/iterations:5000/threads:1        5.98 ms        0.844 ms         5000
cat_op_channel_perf/N:3/C:3/H:221/W:193/iterations:5000/threads:1        6.02 ms        0.845 ms         5000
```
b) Using `vkCmdCopyImage` for channel sizes that are multiples of 4
```
Run on (8 X 1804.8 MHz CPU s)
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
-------------------------------------------------------------------------------------------------------------
Benchmark (With optimization for 4x channels)                               Time             CPU   Iterations
-------------------------------------------------------------------------------------------------------------
cat_op_channel_perf/N:3/C:40/H:221/W:193/iterations:1000/threads:1       39.3 ms         13.3 ms         1000
cat_op_channel_perf/N:3/C:20/H:221/W:193/iterations:1000/threads:1       16.4 ms         3.49 ms         1000
cat_op_channel_perf/N:3/C:39/H:221/W:193/iterations:1000/threads:1       59.7 ms         14.1 ms         1000
cat_op_channel_perf/N:3/C:4/H:221/W:193/iterations:5000/threads:1        3.93 ms        0.855 ms         5000
cat_op_channel_perf/N:3/C:3/H:221/W:193/iterations:5000/threads:1        6.14 ms        0.852 ms         5000
```
* `{3,40,221,193}`: 60.4 ms -> 39.3 ms (34.93% faster)
* `{3,20,221,193}`: 24.1 ms -> 16.4 ms (31.95% faster)
* `{3,4,221,193}`: 5.98 ms -> 3.93 ms (34.28% faster)
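* Percentages are relative to the baseline time, e.g. (60.4 - 39.3) / 60.4 ≈ 34.93%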

{F674052795}

Reviewed By: SS-JIA

Differential Revision: D31781390

fbshipit-source-id: 42179d28ae461a9e247053bc9718f6b8c6c819e5
diff --git a/aten/src/ATen/native/vulkan/api/Helper.cpp b/aten/src/ATen/native/vulkan/api/Helper.cpp
index 6811335..dba7d5a 100644
--- a/aten/src/ATen/native/vulkan/api/Helper.cpp
+++ b/aten/src/ATen/native/vulkan/api/Helper.cpp
@@ -12,17 +12,23 @@
     api::Command::Buffer& command_buffer,
     api::Resource::Image::Object& src_image,
     api::Resource::Image::Object& dst_image,
-    api::utils::uvec3 src_extents,
+    api::utils::uvec3 copy_extents,
+    api::utils::uvec3 src_offset,
     api::utils::uvec3 dst_offset) {
   VkImageCopy copy_info{};
   copy_info.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
   copy_info.srcSubresource.layerCount = 1;
   copy_info.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT;
   copy_info.dstSubresource.layerCount = 1;
-  copy_info.extent.width = src_extents.data[0u];
-  copy_info.extent.height = src_extents.data[1u];
-  copy_info.extent.depth = src_extents.data[2u];
+  copy_info.extent.width = copy_extents.data[0u];
+  copy_info.extent.height = copy_extents.data[1u];
+  copy_info.extent.depth = copy_extents.data[2u];
+  copy_info.srcOffset.x = src_offset.data[0u];
+  copy_info.srcOffset.y = src_offset.data[1u];
+  copy_info.srcOffset.z = src_offset.data[2u];
+  copy_info.dstOffset.x = dst_offset.data[0u];
   copy_info.dstOffset.y = dst_offset.data[1u];
+  copy_info.dstOffset.z = dst_offset.data[2u];
 
   // To use vkCmdCopyImage, the stage of src & dst image must be set to vTensor::Stage::Transfer.
   vkCmdCopyImage(
diff --git a/aten/src/ATen/native/vulkan/api/Helper.h b/aten/src/ATen/native/vulkan/api/Helper.h
index 83d88aa..60d8560 100644
--- a/aten/src/ATen/native/vulkan/api/Helper.h
+++ b/aten/src/ATen/native/vulkan/api/Helper.h
@@ -18,7 +18,8 @@
     api::Command::Buffer& command_buffer,
     api::Resource::Image::Object& src_image,
     api::Resource::Image::Object& dst_image,
-    api::utils::uvec3 src_extents,
+    api::utils::uvec3 copy_extents,
+    api::utils::uvec3 src_offset,
     api::utils::uvec3 dst_offset);
 
 } // namespace utils
diff --git a/aten/src/ATen/native/vulkan/ops/Concat.cpp b/aten/src/ATen/native/vulkan/ops/Concat.cpp
index 79bdf3a..96274b4 100644
--- a/aten/src/ATen/native/vulkan/ops/Concat.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Concat.cpp
@@ -96,6 +96,61 @@
   return convert(v_output);
 }
 
+Tensor cat_feature_mult4ch(const TensorList tensors, vTensor& v_output) {
+  api::Context* const context = api::context();
+  api::Command::Pool& command_pool = context->command().pool;
+  api::Command::Buffer& command_buffer = command_pool.stream();
+
+  int64_t depth_size_allprior = 0;
+  int64_t ch_interval = 0;
+  for (const auto& tensor : tensors) {
+    ch_interval += tensor.sizes()[1];
+  }
+  const int64_t depth_interval = ch_interval / 4;
+
+  auto dst_image = v_output.image(
+    command_buffer,
+    vTensor::Stage::Transfer,
+    vTensor::Access::Write);
+  uvec3 src_offset{};
+  uvec3 dst_offset{};
+
+  for (const auto& tensor : tensors) {
+    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
+    const vTensor& v_self = convert(self);
+    if C10_LIKELY(v_output.has_image() && v_self.has_image()) {
+      auto src_image = v_self.image(
+              command_buffer,
+              vTensor::Stage::Transfer);
+
+      const uint32_t depth_slice = safe_downcast<uint32_t>(tensor.sizes()[1] / 4);
+      uvec3 copy_extents {v_self.extents().data[0u],
+        v_self.extents().data[1u],
+        depth_slice};
+
+      for (int b = 0; b < tensor.sizes()[0]; ++b) {
+        src_offset.data[2u] = safe_downcast<uint32_t>(depth_slice * b);
+        dst_offset.data[2u] = depth_size_allprior + safe_downcast<uint32_t>(depth_interval * b);
+        api::helper::copy_texture_to_texture(command_buffer,
+          src_image,
+          dst_image,
+          copy_extents,
+          src_offset,
+          dst_offset);
+      }
+
+      depth_size_allprior += depth_slice;
+    }
+    else {
+      TORCH_CHECK(false, "Not implemented!");
+    }
+  }
+
+  command_pool.submit(context->gpu().queue, command_buffer);
+
+  return convert(v_output);
+}
+
 Tensor cat_width(const TensorList tensors, vTensor& v_output) {
   TORCH_CHECK(false, "Vulkan cat not implemented for width dimension!");
 }
@@ -110,6 +165,7 @@
     vTensor::Stage::Transfer,
     vTensor::Access::Write);
 
+  uvec3 src_offset{};
   uvec3 dst_offset{};
   for (const auto& tensor : tensors) {
     const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
@@ -123,6 +179,7 @@
         src_image,
         dst_image,
         v_self.extents(),
+        src_offset,
         dst_offset);
 
       // Increment by height
@@ -148,11 +205,16 @@
 
   at::Tensor tensor = tensors[0];
   int64_t cat_dim_size = 0;
+  bool is_mult4ch = true;
 
   for (const auto & t : tensors) {
      TORCH_INTERNAL_ASSERT(
       t.dim() == 4, "Vulkan cat expects 4 dimensional inputs");
 
+    if (t.sizes()[1] % 4 != 0) {
+      is_mult4ch = false;
+    }
+
     for (int d = 0; d < 4; ++d) {
       if (d == dim) {
         continue;
@@ -179,6 +241,9 @@
     return cat_height(tensors, v_output);
   }
   else if (dim == 1) {
+    if (is_mult4ch) {
+      return cat_feature_mult4ch(tensors, v_output);
+    }
     return cat_feature(tensors, v_output);
   }
   return cat_batch(tensors, v_output);
diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp
index 4c44830..4981b3a 100644
--- a/aten/src/ATen/test/vulkan_api_test.cpp
+++ b/aten/src/ATen/test/vulkan_api_test.cpp
@@ -1905,13 +1905,13 @@
   ASSERT_TRUE(check);
 }
 
-TEST(VulkanAPITest, cat_dim1_bat1_ch4multiple_success) {
+TEST(VulkanAPITest, cat_dim1_bat1_mult4ch_success) {
   // Guard
   if (!at::is_vulkan_available()) {
     return;
   }
 
-  // Arrange: batch=1 and channel (multiples of 4 <-> channel %4 == 0)
+  // Arrange: batch=1 and channel (a multiple of 4 <-> channel %4 == 0)
   const auto in_cpu1 = at::rand({1, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
   const auto in_cpu2 = at::rand({1, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
   const auto in_cpu3 = at::rand({1, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
@@ -1929,6 +1929,79 @@
   ASSERT_TRUE(check);
 }
 
+TEST(VulkanAPITest, cat_dim1_bat2_mult4ch_success) {
+  // Guard
+  if (!at::is_vulkan_available()) {
+    return;
+  }
+
+  // Arrange: batch=2 and channel (a multiple of 4 <-> channel %4 == 0)
+  const auto in_cpu1 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu2 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu3 = at::rand({2, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Act
+  const auto out_cpu = at::cat({in_cpu1, in_cpu2, in_cpu3}, 1);
+  const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 1); // dim=feature(channel)
+
+  // Assert
+  const auto check = almostEqual(out_cpu, out_vulkan.cpu());
+  if (!check) {
+    showRtol(out_cpu, out_vulkan.cpu());
+  }
+
+  ASSERT_TRUE(check);
+}
+
+TEST(VulkanAPITest, cat_dim1_mult4ch_mixed_success) {
+  // Guard
+  if (!at::is_vulkan_available()) {
+    return;
+  }
+
+  // Arrange: batch=3 and channel (different multiples of 4 <-> channel %4 == 0)
+  const auto in_cpu1 = at::rand({3, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu2 = at::rand({3, 8, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu3 = at::rand({3, 12, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Act
+  const auto out_cpu = at::cat({in_cpu1, in_cpu2, in_cpu3}, 1);
+  const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 1); // dim=feature(channel)
+
+  // Assert
+  const auto check = almostEqual(out_cpu, out_vulkan.cpu());
+  if (!check) {
+    showRtol(out_cpu, out_vulkan.cpu());
+  }
+
+  ASSERT_TRUE(check);
+}
+
+TEST(VulkanAPITest, cat_dim1_mult4ch_nonmult4ch_success) {
+  // Guard
+  if (!at::is_vulkan_available()) {
+    return;
+  }
+
+  // Arrange: batch=3 and channel (a mixed set of multiples and non-multiples of 4)
+  const auto in_cpu1 = at::rand({3, 3, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu2 = at::rand({3, 4, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu3 = at::rand({3, 7, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu4 = at::rand({3, 8, 221, 193}, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Act
+  const auto out_cpu = at::cat({in_cpu1, in_cpu2, in_cpu3, in_cpu4}, 1);
+  const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan(), in_cpu4.vulkan()}, 1); // dim=feature(channel)
+
+  // Assert
+  const auto check = almostEqual(out_cpu, out_vulkan.cpu());
+  if (!check) {
+    showRtol(out_cpu, out_vulkan.cpu());
+  }
+
+  ASSERT_TRUE(check);
+}
+
 TEST(VulkanAPITest, cat_dim2_sameheight_success) {
   // Guard
   if (!at::is_vulkan_available()) {