[iOS GPU][Kernel] Implement transpose in Metal shaders (#54522)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54522
Implement the transpose operator in Metal shaders using textures.
ghstack-source-id: 126802125
Test Plan:
- Metal operator tests
```
2021-03-22 02:25:53.941006-0700 PyTorchPlayground[57924:9047121] [bool test_transpose()],[1 2 2 5 ],[SUCCEED]
2021-03-22 02:25:53.949834-0700 PyTorchPlayground[57924:9047121] [bool test_transpose2()],[1 2 58 28 28 ],[SUCCEED]
2021-03-22 03:12:19.786584-0700 PyTorchPlayground[58230:9066223] [bool test_transpose3()],[4 5 6 ],[SUCCEED]
```
- Sandcastle CI
- CircleCI
Reviewed By: SS-JIA
Differential Revision: D27225940
fbshipit-source-id: 14bfb96435a39aecf4f14bc5e2f7232421014328
diff --git a/aten/src/ATen/native/metal/MetalShaders.h b/aten/src/ATen/native/metal/MetalShaders.h
index c275038..dd87d20 100644
--- a/aten/src/ATen/native/metal/MetalShaders.h
+++ b/aten/src/ATen/native/metal/MetalShaders.h
@@ -590,7 +590,7 @@
// we compute the "linear index" of the output element,
// and convert it to the equivalent "linear index" of the input element.
ushort offset = 4 * s2 + idx;
- ushort linear_idx = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
+ int64_t linear_idx = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
if(linear_idx >= numel1){
value[idx] = 0;
continue;
@@ -615,6 +615,98 @@
}
}
+// Function constants choosing which texture flavour is bound at each slot.
+// An image with more than one batch entry (ushort_arg_3 / ushort_arg_5) or more
+// than 4 feature channels (ushort_arg_4 / ushort_arg_6) is bound as a
+// texture2d_array; otherwise it is bound as a plain texture2d.
+constant bool transpose_in_is_arr = (ushort_arg_3 > 1 || ushort_arg_4 > 4);
+constant bool transpose_in_is_tex = !transpose_in_is_arr;
+constant bool transpose_out_is_arr = (ushort_arg_5 > 1 || ushort_arg_6 > 4);
+constant bool transpose_out_is_tex = !transpose_out_is_arr;
+// Swaps dimensions ushort_arg_0 and ushort_arg_1 of a tensor of rank ushort_arg_2.
+//
+// Each thread produces one half4 texel of the output. For every lane of that
+// texel it computes the lane's linear index in the output tensor, unravels it
+// into per-dimension coordinates using outSizeBuffer, and re-ravels those
+// coordinates (with the two transposed dimensions exchanged) into a linear
+// index of the input tensor using the strides implied by inSizeBuffer.
+//
+// inSizeBuffer  - sizes of the input tensor, one ushort per dimension.
+// outSizeBuffer - sizes of the output tensor, one ushort per dimension.
+// indexBuffer   - no longer touched. It used to serve as coordinate scratch,
+//                 but it is one device buffer shared by every thread in the
+//                 grid, so concurrent threads overwrote each other's
+//                 coordinates. The parameter is kept only so that the host-side
+//                 binding at buffer(2) continues to match the signature.
+kernel void transpose(texture2d_array<half, access::read>in_arr[[texture(0),function_constant(transpose_in_is_arr)]],
+                      texture2d<half, access::read> in_tex[[texture(0), function_constant(transpose_in_is_tex)]],
+                      texture2d_array<half, access::write>out_arr[[texture(1),function_constant(transpose_out_is_arr)]],
+                      texture2d<half, access::write> out_tex[[texture(1), function_constant(transpose_out_is_tex)]],
+                      constant ushort* inSizeBuffer [[buffer(0)]],
+                      constant ushort* outSizeBuffer [[buffer(1)]],
+                      device ushort* indexBuffer [[buffer(2)]],
+                      ushort3 gid[[thread_position_in_grid]]) {
+
+    const ushort dim0 = ushort_arg_0;
+    const ushort dim1 = ushort_arg_1;
+    const ushort dim = ushort_arg_2;
+    const ushort N1 = ushort_arg_3;
+    const ushort C1 = ushort_arg_4;
+    const ushort N2 = ushort_arg_5;
+    const ushort C2 = ushort_arg_6;
+    ushort W1,W2,H1,H2;
+    if(transpose_in_is_arr) {
+        W1 = in_arr.get_width();
+        H1 = in_arr.get_height();
+    } else {
+        W1 = in_tex.get_width();
+        H1 = in_tex.get_height();
+    }
+    if(transpose_out_is_arr) {
+        W2 = out_arr.get_width();
+        H2 = out_arr.get_height();
+    } else {
+        W2 = out_tex.get_width();
+        H2 = out_tex.get_height();
+    }
+    if (gid.x >= W2 || gid.y >= H2) {
+        return;
+    }
+    const int numel = H2 * W2 * C2 * N2;
+    const ushort slices2 = divRoundUp(C2, 4);
+    const ushort slices1 = divRoundUp(C1, 4);
+    const ushort n2 = gid.z / slices2;
+    const ushort s2 = gid.z - n2 * slices2;
+
+    // Input strides of the two transposed dimensions, i.e. the product of the
+    // input sizes that follow each of them. They are the same for all four
+    // lanes, so compute them once. 32-bit accumulators are used because a
+    // ushort stride wraps as soon as the product reaches 65536.
+    uint stride_dim0 = 1;
+    for (int k = dim - 1; k > dim0; --k) {
+        stride_dim0 *= inSizeBuffer[k];
+    }
+    uint stride_dim1 = 1;
+    for (int k = dim - 1; k > dim1; --k) {
+        stride_dim1 *= inSizeBuffer[k];
+    }
+
+    half4 value;
+    for (int idx = 0; idx < 4; ++idx){
+        ushort offset = 4 * s2 + idx;
+        int64_t linear_idx2 = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
+        if(linear_idx2 >= numel) {
+            // Padding lane past the end of the tensor.
+            value[idx] = 0;
+            continue;
+        }
+
+        // Walk the dimensions from innermost to outermost. The output
+        // coordinate of dimension j becomes the input coordinate of the
+        // swapped dimension, so it is weighted with that dimension's input
+        // stride; all other dimensions keep their own running stride.
+        int64_t linear_idx1 = 0;
+        uint stride = 1;
+        for (int j = dim - 1; j >= 0; --j) {
+            const ushort d2 = outSizeBuffer[j];
+            const int64_t coord = linear_idx2 % d2;
+            linear_idx2 /= d2;
+            const uint s = (j == dim0) ? stride_dim1 : ((j == dim1) ? stride_dim0 : stride);
+            linear_idx1 += coord * s;
+            stride *= inSizeBuffer[j];
+        }
+
+        // Convert the input linear index into texture coordinates: x/y inside
+        // the slice, z selecting the slice, pos selecting the lane.
+        auto x1 = linear_idx1 % W1;
+        auto y1 = ((int)(linear_idx1/W1)) % H1;
+        auto c1 = ((int)(linear_idx1/W1/H1) % C1);
+        auto n1 = ((int)(linear_idx1/W1/H1/C1) % N1);
+        auto z1 = (int)c1 / 4 + n1 * slices1;
+        auto pos = c1 % 4;
+        if(transpose_in_is_arr) {
+            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
+        } else {
+            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
+        }
+    }
+    if(transpose_out_is_arr) {
+        out_arr.write(value, gid.xy, gid.z);
+    } else {
+        out_tex.write(value, gid.xy);
+    }
+}
+
)PT_METAL_SHADERS";
#endif /* MPSCNNShaders_h */
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
index 0dc7be4..85d796a 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
@@ -24,6 +24,9 @@
bool test_div_broadcast();
bool test_div_broadcast2();
bool test_t();
+bool test_transpose();
+bool test_transpose2();
+bool test_transpose3();
bool test_view();
bool test_view2();
bool test_view3();
diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
index aa18944..5951276 100644
--- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
+++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
@@ -4,7 +4,6 @@
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#import <ATen/native/metal/mpscnn/tests/MPSCNNTests.h>
#import <ATen/native/metal/ops/MetalConvolution.h>
-#import <ATen/native/metal/ops/MetalTranspose.h>
#import <Foundation/Foundation.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
@@ -490,7 +489,7 @@
auto X1 = at::rand({H, W}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
auto Y1 = at::t(X1).contiguous();
auto X2 = X1.metal();
- auto Y2 = at::native::metal::t(X2).cpu();
+ auto Y2 = at::t(X2).cpu();
return almostEqual(Y1, Y2);
});
if (!b) {
@@ -500,6 +499,39 @@
return result;
}
+// Transposes dims 1 and 3 of a {1, 2, 2, 5} tensor on Metal and compares the
+// result against the contiguous CPU transpose of the same data.
+bool test_transpose() {
+  __block std::vector<int64_t> size{1, 2, 2, 5};
+  return TEST(size, __PRETTY_FUNCTION__, ^bool {
+    const auto cpuOptions = at::TensorOptions(at::kCPU).dtype(at::kFloat);
+    auto cpuInput = at::rand(size, cpuOptions);
+    auto expected = at::transpose(cpuInput, 1, 3).contiguous();
+    auto metalInput = cpuInput.metal();
+    auto actual = at::transpose(metalInput, 1, 3).cpu();
+    return almostEqual(expected, actual);
+  });
+}
+
+// Transposes dims 1 and 2 of a 5-D {1, 2, 58, 28, 28} tensor on Metal and
+// compares the result against the contiguous CPU transpose of the same data.
+bool test_transpose2() {
+  __block std::vector<int64_t> size{1, 2, 58, 28, 28};
+  return TEST(size, __PRETTY_FUNCTION__, ^bool {
+    const auto cpuOptions = at::TensorOptions(at::kCPU).dtype(at::kFloat);
+    auto cpuInput = at::rand(size, cpuOptions);
+    auto expected = at::transpose(cpuInput, 1, 2).contiguous();
+    auto metalInput = cpuInput.metal();
+    auto actual = at::transpose(metalInput, 1, 2).cpu();
+    return almostEqual(expected, actual);
+  });
+}
+
+// Transposes dims 2 and 0 of a 3-D {4, 5, 6} tensor on Metal and compares the
+// result against the contiguous CPU transpose of the same data.
+bool test_transpose3() {
+  __block std::vector<int64_t> size{4, 5, 6};
+  return TEST(size, __PRETTY_FUNCTION__, ^bool {
+    const auto cpuOptions = at::TensorOptions(at::kCPU).dtype(at::kFloat);
+    auto cpuInput = at::rand(size, cpuOptions);
+    auto expected = at::transpose(cpuInput, 2, 0).contiguous();
+    auto metalInput = cpuInput.metal();
+    auto actual = at::transpose(metalInput, 2, 0).cpu();
+    return almostEqual(expected, actual);
+  });
+}
+
bool test_view() {
// array -> array
__block std::vector<int64_t> size{1, 10, 2, 2};
diff --git a/aten/src/ATen/native/metal/ops/MetalTranspose.h b/aten/src/ATen/native/metal/ops/MetalTranspose.h
deleted file mode 100644
index 9353b6b..0000000
--- a/aten/src/ATen/native/metal/ops/MetalTranspose.h
+++ /dev/null
@@ -1,17 +0,0 @@
-#ifndef MetalCopy_h
-#define MetalCopy_h
-
-#include <ATen/Tensor.h>
-
-namespace at {
-namespace native {
-namespace metal {
-
-// TODO: Remove the header once we are able to call it through dispatcher
-Tensor t(const Tensor& input);
-
-} // namespace metal
-} // namespace native
-} // namespace at
-
-#endif
diff --git a/aten/src/ATen/native/metal/ops/MetalTranspose.mm b/aten/src/ATen/native/metal/ops/MetalTranspose.mm
index e174609..f2b791a 100644
--- a/aten/src/ATen/native/metal/ops/MetalTranspose.mm
+++ b/aten/src/ATen/native/metal/ops/MetalTranspose.mm
@@ -3,37 +3,94 @@
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
+#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
+
+#include <ATen/ATen.h>
#include <torch/library.h>
namespace at {
namespace native {
namespace metal {
+// Swaps dimensions `dim0` and `dim1` of a Metal tensor.
+//
+// Both dimensions are wrapped with maybe_wrap_dim; if they resolve to the same
+// dimension the input is returned untouched. A 2-D input is handed to
+// MPSImageTranspose; any other rank is encoded with the "transpose" compute
+// kernel, specialized on the two dimensions, the rank, and the batch/channel
+// counts of the source and destination images.
+Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) {
+  TORCH_CHECK(input.is_metal());
+  const auto rank = input.dim();
+  dim0 = maybe_wrap_dim(dim0, rank);
+  dim1 = maybe_wrap_dim(dim1, rank);
+  if (dim0 == dim1) {
+    return input;
+  }
+
+  auto outputSizes = input.sizes().vec();
+  std::swap(outputSizes[dim0], outputSizes[dim1]);
+  MPSImage* X = imageFromTensor(input);
+  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
+
+  // Destination image; both code paths below write into it.
+  MetalTensorImplStorage mt{outputSizes};
+  mt.texture()->allocateTemporaryTextureStorage(outputSizes, commandBuffer);
+  MPSImage* Y = mt.texture()->image();
+
+  if (rank == 2) {
+    // Plain matrix transpose: let MPS do it.
+    MPSImageTranspose* mpsTranspose = [[MPSImageTranspose alloc]
+        initWithDevice:[MPSCNNContext sharedInstance].device];
+    [mpsTranspose encodeToCommandBuffer:commandBuffer.buffer
+                            sourceImage:X
+                       destinationImage:Y];
+    return makeTensor(std::move(mt), input.options());
+  }
+
+  // General case: pass the input/output sizes (narrowed to ushort) plus an
+  // index buffer holding one ushort per dimension to the custom kernel.
+  const auto inputSizes = input.sizes();
+  id<MTLBuffer> inputSizeBuffer = makeMTLBuffer<ushort>(
+      std::vector<ushort>{inputSizes.begin(), inputSizes.end()});
+  id<MTLBuffer> outputSizeBuffer = makeMTLBuffer<ushort>(
+      std::vector<ushort>{outputSizes.begin(), outputSizes.end()});
+  id<MTLBuffer> indexBuffer = makeMTLBuffer(std::vector<ushort>(rank, 1));
+
+  id<MTLComputeCommandEncoder> encoder =
+      [commandBuffer.buffer computeCommandEncoder];
+  id<MTLComputePipelineState> pipeline =
+      [[MPSCNNContext sharedInstance] specializedPipelineState:@"transpose"
+                                                     Constants:@[
+                                                       @(dim0),
+                                                       @(dim1),
+                                                       @(rank),
+                                                       @(X.numberOfImages),
+                                                       @(X.featureChannels),
+                                                       @(Y.numberOfImages),
+                                                       @(Y.featureChannels),
+                                                     ]];
+
+  [encoder setComputePipelineState:pipeline];
+  [encoder setTexture:[X texture] atIndex:0];
+  [encoder setTexture:[Y texture] atIndex:1];
+  [encoder setBuffer:inputSizeBuffer offset:0 atIndex:0];
+  [encoder setBuffer:outputSizeBuffer offset:0 atIndex:1];
+  [encoder setBuffer:indexBuffer offset:0 atIndex:2];
+
+  const auto& launchParams =
+      mpscnn::spatialPointwiseKernelLaunchParams(pipeline, Y);
+  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
+          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
+  [encoder endEncoding];
+  [X markRead];
+  [Y markRead];
+
+  return makeTensor(std::move(mt), input.options());
+}
+
+// Metal implementation of aten::t. Accepts tensors of at most 2 dimensions and
+// forwards to metal::transpose on dims (0, 1). For inputs with fewer than 2
+// dimensions both arguments are 0, so transpose() returns the input unchanged.
Tensor t(const Tensor& input) {
TORCH_CHECK(input.is_metal());
- TORCH_CHECK(input.is_metal());
- TORCH_CHECK(input.dim() == 2);
- auto strides = input.strides().vec();
- auto sizes = input.sizes().vec();
- MPSImage* X = imageFromTensor(input);
- TORCH_CHECK(X.numberOfImages == 1);
- TORCH_CHECK(X.featureChannels == 1);
- MetalTensorImplStorage mt({sizes[1], sizes[0]});
- MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
- mt.texture()->allocateTemporaryTextureStorage(
- {1, 1, sizes[1], sizes[0]}, commandBuffer);
- MPSImage* Y = mt.texture()->image();
- MPSImageTranspose* transpose = [[MPSImageTranspose alloc]
- initWithDevice:[MPSCNNContext sharedInstance].device];
- [transpose encodeToCommandBuffer:commandBuffer.buffer
- sourceImage:X
- destinationImage:Y];
- auto output = makeTensor(std::move(mt), input.options());
- return output;
+  // `<= 2` rather than `== 2`: with the stricter check the `dim() < 2` operand
+  // below could never be taken and 0-D/1-D inputs were rejected outright.
+  TORCH_CHECK(input.dim() <= 2);
+  return metal::transpose(input, 0, input.dim() < 2 ? 0 : 1);
}
+// Registers the Metal-backend kernels for aten::t and aten::transpose.int.
+// No trailing ';' after the closing brace: it would be a stray empty
+// declaration at namespace scope.
+TORCH_LIBRARY_IMPL(aten, Metal, m) {
+  m.impl("t", TORCH_FN(t));
+  m.impl("transpose.int", TORCH_FN(transpose));
+}
+
}
}
}