[iOS GPU][Kernel] Implement transpose in Metal shaders (#54522)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54522

Implement the transpose operator in Metal shaders using textures.
ghstack-source-id: 126802125
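
A minimal usage sketch, mirroring the unit tests below: the op is registered for the Metal backend (`t` and `transpose.int`), so it is reached through the regular ATen API on a Metal tensor. The helper name `transpose_on_metal` and the 4-D input / dims `(1, 3)` are illustrative.

```cpp
#include <ATen/ATen.h>

// Upload a CPU tensor to the Metal backend, transpose two dims on the GPU,
// and copy the result back to the CPU for comparison against the CPU op.
at::Tensor transpose_on_metal(const at::Tensor& cpu_input) {
  auto x_metal = cpu_input.metal();             // MPSImage-backed Metal tensor
  auto y_metal = at::transpose(x_metal, 1, 3);  // dispatched to the Metal kernel
  return y_metal.cpu();                         // download the transposed result
}
```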

Test Plan:
- Metal operator tests
```
2021-03-22 02:25:53.941006-0700 PyTorchPlayground[57924:9047121] [bool test_transpose()],[1 2 2 5 ],[SUCCEED]
2021-03-22 02:25:53.949834-0700 PyTorchPlayground[57924:9047121] [bool test_transpose2()],[1 2 58 28 28 ],[SUCCEED]
2021-03-22 03:12:19.786584-0700 PyTorchPlayground[58230:9066223] [bool test_transpose3()],[4 5 6 ],[SUCCEED]
```
- Sandcastle CI
- CircleCI

Reviewed By: SS-JIA

Differential Revision: D27225940

fbshipit-source-id: 14bfb96435a39aecf4f14bc5e2f7232421014328
diff --git a/aten/src/ATen/native/metal/MetalShaders.h b/aten/src/ATen/native/metal/MetalShaders.h
index c275038..dd87d20 100644
--- a/aten/src/ATen/native/metal/MetalShaders.h
+++ b/aten/src/ATen/native/metal/MetalShaders.h
@@ -590,7 +590,7 @@
         // we compute the "linear index" of the output element,
         // and convert it to the equivalent "linear index" of the input element.
         ushort offset = 4 * s2 + idx;
-        ushort linear_idx = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
+        int64_t linear_idx = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
         if(linear_idx >= numel1){
             value[idx] = 0;
             continue;
@@ -615,6 +615,98 @@
     }
 }
 
+constant bool transpose_in_is_arr = (ushort_arg_3 > 1 || ushort_arg_4 > 4);
+constant bool transpose_in_is_tex = !transpose_in_is_arr;
+constant bool transpose_out_is_arr = (ushort_arg_5 > 1 || ushort_arg_6 > 4);
+constant bool transpose_out_is_tex = !transpose_out_is_arr;
+kernel void transpose(texture2d_array<half, access::read>in_arr[[texture(0),function_constant(transpose_in_is_arr)]],
+                      texture2d<half, access::read> in_tex[[texture(0), function_constant(transpose_in_is_tex)]],
+                      texture2d_array<half, access::write>out_arr[[texture(1),function_constant(transpose_out_is_arr)]],
+                      texture2d<half, access::write> out_tex[[texture(1), function_constant(transpose_out_is_tex)]],
+                      constant ushort* inSizeBuffer [[buffer(0)]],
+                      constant ushort* outSizeBuffer [[buffer(1)]],
+                      device ushort* indexBuffer [[buffer(2)]],
+                      ushort3 gid[[thread_position_in_grid]]) {
+
+    const ushort dim0 = ushort_arg_0;
+    const ushort dim1 = ushort_arg_1;
+    const ushort dim = ushort_arg_2;
+    const ushort N1 = ushort_arg_3;
+    const ushort C1 = ushort_arg_4;
+    const ushort N2 = ushort_arg_5;
+    const ushort C2 = ushort_arg_6;
+    ushort W1,W2,H1,H2;
+    if(transpose_in_is_arr) {
+        W1 = in_arr.get_width();
+        H1 = in_arr.get_height();
+    } else {
+        W1 = in_tex.get_width();
+        H1 = in_tex.get_height();
+    }
+    if(transpose_out_is_arr) {
+        W2 = out_arr.get_width();
+        H2 = out_arr.get_height();
+    } else {
+        W2 = out_tex.get_width();
+        H2 = out_tex.get_height();
+    }
+    if (gid.x >= W2 || gid.y >= H2) {
+        return;
+    }
+    const int numel = H2 * W2 * C2 * N2;
+    const ushort slices2 = divRoundUp(C2, 4);
+    const ushort slices1 = divRoundUp(C1, 4);
+    const ushort n2 = gid.z / slices2;
+    const ushort s2 = gid.z - n2 * slices2;
+    half4 value;
+    for (int idx = 0; idx < 4; ++idx){
+        ushort offset = 4 * s2 + idx;
+        int64_t linear_idx2 = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
+        if(linear_idx2 >= numel) {
+            value[idx] = 0;
+            continue;
+        }
+
+        ushort d2 = 0;
+        for(int j = dim-1; j>=0; --j){
+            d2  = outSizeBuffer[j];
+            indexBuffer[j] = linear_idx2 % d2;
+            linear_idx2 /= d2;
+        }
+
+        // swap dims
+        ushort tmp = indexBuffer[dim0];
+        indexBuffer[dim0] = indexBuffer[dim1];
+        indexBuffer[dim1] = tmp;
+
+        int64_t linear_idx1 = 0;
+        ushort m = 1;
+        ushort d1 = 0;
+        for(int k = dim-1; k>=0; --k) {
+            d1 = indexBuffer[k];
+            linear_idx1 += d1 * m;
+            m *= inSizeBuffer[k];
+        }
+
+        auto x1 = linear_idx1 % W1;
+        auto y1 = ((int)(linear_idx1/W1)) % H1;
+        auto c1 = ((int)(linear_idx1/W1/H1) % C1);
+        auto n1 = ((int)(linear_idx1/W1/H1/C1) % N1);
+        auto z1 = (int)c1 / 4 + n1 * slices1;
+        auto pos = c1 % 4;
+        if(transpose_in_is_arr) {
+            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
+        } else {
+            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
+        }
+    }
+    if(transpose_out_is_arr) {
+        out_arr.write(value, gid.xy, gid.z);
+    } else {
+        out_tex.write(value, gid.xy);
+    }
+}
+
 )PT_METAL_SHADERS";
 
 #endif /* MPSCNNShaders_h */
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
index 0dc7be4..85d796a 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
@@ -24,6 +24,9 @@
 bool test_div_broadcast();
 bool test_div_broadcast2();
 bool test_t();
+bool test_transpose();
+bool test_transpose2();
+bool test_transpose3();
 bool test_view();
 bool test_view2();
 bool test_view3();
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
index aa18944..5951276 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
@@ -4,7 +4,6 @@
 #import <ATen/native/metal/mpscnn/MPSImageUtils.h>
 #import <ATen/native/metal/mpscnn/tests/MPSCNNTests.h>
 #import <ATen/native/metal/ops/MetalConvolution.h>
-#import <ATen/native/metal/ops/MetalTranspose.h>
 
 #import <Foundation/Foundation.h>
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
@@ -490,7 +489,7 @@
       auto X1 = at::rand({H, W}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
       auto Y1 = at::t(X1).contiguous();
       auto X2 = X1.metal();
-      auto Y2 = at::native::metal::t(X2).cpu();
+      auto Y2 = at::t(X2).cpu();
       return almostEqual(Y1, Y2);
     });
     if (!b) {
@@ -500,6 +499,39 @@
   return result;
 }
 
+bool test_transpose() {
+    __block std::vector<int64_t> size {1, 2, 2, 5};
+    return TEST(size, __PRETTY_FUNCTION__, ^bool{
+        auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+        auto Y1 = at::transpose(X1, 1, 3).contiguous();
+        auto X2 = X1.metal();
+        auto Y2 = at::transpose(X2, 1, 3).cpu();
+        return almostEqual(Y1, Y2);
+    });
+}
+
+bool test_transpose2() {
+    __block std::vector<int64_t> size {1, 2, 58, 28, 28};
+    return TEST(size, __PRETTY_FUNCTION__, ^bool{
+        auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+        auto Y1 = at::transpose(X1, 1, 2).contiguous();
+        auto X2 = X1.metal();
+        auto Y2 = at::transpose(X2, 1, 2).cpu();
+        return almostEqual(Y1, Y2);
+    });
+}
+
+bool test_transpose3() {
+    __block std::vector<int64_t> size {4, 5, 6};
+    return TEST(size, __PRETTY_FUNCTION__, ^bool{
+        auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+        auto Y1 = at::transpose(X1, 2, 0).contiguous();
+        auto X2 = X1.metal();
+        auto Y2 = at::transpose(X2, 2, 0).cpu();
+        return almostEqual(Y1, Y2);
+    });
+}
+
 bool test_view() {
   // array -> array
   __block std::vector<int64_t> size{1, 10, 2, 2};
diff --git a/aten/src/ATen/native/metal/ops/MetalTranspose.h b/aten/src/ATen/native/metal/ops/MetalTranspose.h
deleted file mode 100644
index 9353b6b..0000000
--- a/aten/src/ATen/native/metal/ops/MetalTranspose.h
+++ /dev/null
@@ -1,17 +0,0 @@
-#ifndef MetalCopy_h
-#define MetalCopy_h
-
-#include <ATen/Tensor.h>
-
-namespace at {
-namespace native {
-namespace metal {
-
-// TODO: Remove the header once we are able to call it through dispatcher
-Tensor t(const Tensor& input);
-
-} // namespace metal
-} // namespace native
-} // namespace at
-
-#endif
diff --git a/aten/src/ATen/native/metal/ops/MetalTranspose.mm b/aten/src/ATen/native/metal/ops/MetalTranspose.mm
index e174609..f2b791a 100644
--- a/aten/src/ATen/native/metal/ops/MetalTranspose.mm
+++ b/aten/src/ATen/native/metal/ops/MetalTranspose.mm
@@ -3,37 +3,94 @@
 #import <ATen/native/metal/MetalTensorImplStorage.h>
 #import <ATen/native/metal/MetalUtils.h>
 #import <ATen/native/metal/mpscnn/MPSCNNContext.h>
+#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
 #import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
 #import <ATen/native/metal/mpscnn/MPSImageUtils.h>
+
+#include <ATen/ATen.h>
 #include <torch/library.h>
 
 namespace at {
 namespace native {
 namespace metal {
 
+Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) {
+  TORCH_CHECK(input.is_metal());
+  auto ndims = input.dim();
+  dim0 = maybe_wrap_dim(dim0, ndims);
+  dim1 = maybe_wrap_dim(dim1, ndims);
+  if (dim0 == dim1) {
+    return input;
+  }
+  auto outputSizes = input.sizes().vec();
+  std::swap(outputSizes[dim0], outputSizes[dim1]);
+  MPSImage* X = imageFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+  if (input.dim() == 2) {
+    MetalTensorImplStorage mt{outputSizes};
+    mt.texture()->allocateTemporaryTextureStorage(outputSizes, commandBuffer);
+    MPSImage* Y = mt.texture()->image();
+    MPSImageTranspose* transpose = [[MPSImageTranspose alloc]
+        initWithDevice:[MPSCNNContext sharedInstance].device];
+    [transpose encodeToCommandBuffer:commandBuffer.buffer
+                         sourceImage:X
+                    destinationImage:Y];
+    auto output = makeTensor(std::move(mt), input.options());
+    return output;
+  } else {
+    id<MTLBuffer> sizeBuf1 = makeMTLBuffer<ushort>(
+        std::vector<ushort>{input.sizes().begin(), input.sizes().end()});
+    id<MTLBuffer> sizeBuf2 = makeMTLBuffer<ushort>(
+        std::vector<ushort>{outputSizes.begin(), outputSizes.end()});
+    id<MTLBuffer> indexBuf = makeMTLBuffer(std::vector<ushort>(input.dim(), 1));
+    MetalTensorImplStorage mt{outputSizes};
+    mt.texture()->allocateTemporaryTextureStorage(outputSizes, commandBuffer);
+    MPSImage* Y = mt.texture()->image();
+    id<MTLComputeCommandEncoder> encoder =
+        [commandBuffer.buffer computeCommandEncoder];
+    id<MTLComputePipelineState> state =
+        [[MPSCNNContext sharedInstance] specializedPipelineState:@"transpose"
+                                                       Constants:@[
+                                                         @(dim0),
+                                                         @(dim1),
+                                                         @(input.dim()),
+                                                         @(X.numberOfImages),
+                                                         @(X.featureChannels),
+                                                         @(Y.numberOfImages),
+                                                         @(Y.featureChannels),
+                                                       ]];
+
+    [encoder setComputePipelineState:state];
+    [encoder setTexture:[X texture] atIndex:0];
+    [encoder setTexture:[Y texture] atIndex:1];
+    [encoder setBuffer:sizeBuf1 offset:0 atIndex:0];
+    [encoder setBuffer:sizeBuf2 offset:0 atIndex:1];
+    [encoder setBuffer:indexBuf offset:0 atIndex:2];
+
+    const auto& launchParams =
+        mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
+    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
+            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
+    [encoder endEncoding];
+    [X markRead];
+    [Y markRead];
+
+    auto output = makeTensor(std::move(mt), input.options());
+    return output;
+  }
+}
+
 Tensor t(const Tensor& input) {
   TORCH_CHECK(input.is_metal());
-  TORCH_CHECK(input.is_metal());
   TORCH_CHECK(input.dim() == 2);
-  auto strides = input.strides().vec();
-  auto sizes = input.sizes().vec();
-  MPSImage* X = imageFromTensor(input);
-  TORCH_CHECK(X.numberOfImages == 1);
-  TORCH_CHECK(X.featureChannels == 1);
-  MetalTensorImplStorage mt({sizes[1], sizes[0]});
-  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
-  mt.texture()->allocateTemporaryTextureStorage(
-      {1, 1, sizes[1], sizes[0]}, commandBuffer);
-  MPSImage* Y = mt.texture()->image();
-  MPSImageTranspose* transpose = [[MPSImageTranspose alloc]
-      initWithDevice:[MPSCNNContext sharedInstance].device];
-  [transpose encodeToCommandBuffer:commandBuffer.buffer
-                       sourceImage:X
-                  destinationImage:Y];
-  auto output = makeTensor(std::move(mt), input.options());
-  return output;
+  return metal::transpose(input, 0, input.dim() < 2 ? 0 : 1);
 }
 
+TORCH_LIBRARY_IMPL(aten, Metal, m) {
+  m.impl("t", TORCH_FN(t));
+  m.impl("transpose.int", TORCH_FN(transpose));
+};
+
 }
 }
 }