Fix bug in caffe2 transpose on GPU (#22233)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22233
Fix bug in caffe2 transpose on GPU
Reviewed By: hl475
Differential Revision: D15994973
fbshipit-source-id: 542dc8757b51a6322fffa55826c1d4e32927398d
diff --git a/caffe2/python/operator_test/transpose_op_test.py b/caffe2/python/operator_test/transpose_op_test.py
index e8aa460..0610476 100644
--- a/caffe2/python/operator_test/transpose_op_test.py
+++ b/caffe2/python/operator_test/transpose_op_test.py
@@ -38,6 +38,19 @@
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0])
+ @given(M=st.integers(10, 200), N=st.integers(10, 200), **hu.gcs)
+ def test_transpose_large_matrix(self, M, N, gc, dc):
+ op = core.CreateOperator("Transpose", ["X"], ["Y"], device_option=gc)
+ X = np.random.rand(M, N).astype(np.float32) - 0.5
+
+ def transpose_ref(X):
+ return [np.transpose(X)]
+
+ self.assertReferenceChecks(gc, op, [X], transpose_ref)
+ self.assertDeviceChecks(dc, op, [X], [0])
+ self.assertGradientChecks(gc, op, [X], 0, [0])
+
+
@unittest.skipIf(not workspace.has_cuda_support, "no cuda support")
@given(X=hu.tensor(dtype=np.float32), use_axes=st.booleans(),
**hu.gcs_cuda_only)
diff --git a/caffe2/utils/math/transpose.cu b/caffe2/utils/math/transpose.cu
index 48a6fa2..a02a8d2 100644
--- a/caffe2/utils/math/transpose.cu
+++ b/caffe2/utils/math/transpose.cu
@@ -37,7 +37,7 @@
int x = c * kTileDim + threadIdx.x;
int y = r * kTileDim + threadIdx.y;
if (x < W) {
- for (int i = 0; i < kTileDim && y + i < H; i += kBlockRows) {
+ for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
tile[threadIdx.y + i][threadIdx.x] = __ldg(X + offset + (y + i) * W + x);
#else
@@ -49,7 +49,7 @@
x = r * kTileDim + threadIdx.x;
y = c * kTileDim + threadIdx.y;
if (x < H) {
- for (int i = 0; i < kTileDim && y + i < W; i += kBlockRows) {
+ for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
}
}