Fix upsample kernel launch / reorder arguments (#20505)
Summary:
This is a follow-up to https://github.com/pytorch/pytorch/pull/19630.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20505
Differential Revision: D15392706
Pulled By: ezyang
fbshipit-source-id: 5a8a7aacdbcf740508baf2b6e0c081c4e5a0390f
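
The launch-configuration half of the change is easiest to see in isolation. The sketch below is illustrative only (the kernel and helper names are made up, not part of ATen): with C10_LAUNCH_BOUNDS_1(1024) now applied unconditionally instead of only under __HIP_PLATFORM_HCC__, the host side must never request more threads per block than the declared bound, so the block size is clamped with std::min rather than taken from maxThreadsPerBlock alone (or halved, as the bicubic path did before).

// Minimal sketch (assumed names, not the ATen source): clamp the block size
// to the 1024-thread bound declared on the kernel, mirroring
// std::min(maxThreadsPerBlock, 1024) in the patch below.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void __launch_bounds__(1024) upsample_example_kernel(
    const float* in, float* out, int n) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i < n) {
    out[i] = in[i]; // placeholder body; the real kernels interpolate
  }
}

void launch_with_bound(const float* in, float* out, int n, cudaStream_t stream) {
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, /*device=*/0);
  // Never request more threads per block than the kernel's declared bound,
  // while still respecting devices whose hardware limit is below 1024.
  const int num_threads = std::min(props.maxThreadsPerBlock, 1024);
  const int num_blocks = (n + num_threads - 1) / num_threads;
  upsample_example_kernel<<<num_blocks, num_threads, 0, stream>>>(in, out, n);
}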
diff --git a/aten/src/ATen/native/cuda/UpSample.cuh b/aten/src/ATen/native/cuda/UpSample.cuh
index 8978134..c0842f8 100644
--- a/aten/src/ATen/native/cuda/UpSample.cuh
+++ b/aten/src/ATen/native/cuda/UpSample.cuh
@@ -158,51 +158,42 @@
const float scale,
int dst_index,
int input_size) {
- const int src_index = min<int>(
- static_cast<int>(floorf(dst_index * scale)), input_size - 1);
+ const int src_index =
+ min<int>(static_cast<int>(floorf(dst_index * scale)), input_size - 1);
return src_index;
}
-/* just affect UpSampleBicubic2d.cu */
-/* TODO: change width and height order in the arguments */
-/* TODO: maybe change x and y order in the arguments */
-/* TODO: maybe change channel and batch order in the arguments */
+/* Used by UpSampleBicubic2d.cu */
template <typename scalar_t>
__device__ __forceinline__ static scalar_t upsample_get_value_bounded(
const PackedTensorAccessor<scalar_t, 4>& data,
- int channel,
int batch,
- int width,
+ int channel,
int height,
- int x,
- int y) {
- int access_x =
- max<int>(min<int>(x, width - 1), static_cast<int>(0));
- int access_y =
- max<int>(min<int>(y, height - 1), static_cast<int>(0));
+ int width,
+ int y,
+ int x) {
+ int access_y = max<int>(min<int>(y, height - 1), 0);
+ int access_x = max<int>(min<int>(x, width - 1), 0);
return data[batch][channel][access_y][access_x];
}
-/* just affect UpSampleBicubic2d.cu */
-/* TODO: change width and height order in the arguments */
-/* TODO: maybe change x and y order in the arguments */
-/* TODO: maybe change channel and batch order in the arguments */
+/* Used by UpSampleBicubic2d.cu */
template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ static void upsample_increment_value_bounded(
PackedTensorAccessor<scalar_t, 4>& data,
- int channel,
int batch,
- int width,
+ int channel,
int height,
- int x,
+ int width,
int y,
+ int x,
accscalar_t value) {
- int access_x =
- max<int>(min<int>(x, width - 1), static_cast<int>(0));
- int access_y =
- max<int>(min<int>(y, height - 1), static_cast<int>(0));
- /* TODO: result here is trucated to scalar_t,
- check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912 */
+ int access_y = max<int>(min<int>(y, height - 1), 0);
+ int access_x = max<int>(min<int>(x, width - 1), 0);
+  /* TODO: result here is truncated to scalar_t,
+ check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912
+ */
atomicAdd(
&data[batch][channel][access_y][access_x], static_cast<scalar_t>(value));
}
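
The reordered signatures above take their arguments in the same order as the indexing they perform, data[batch][channel][access_y][access_x]. A minimal sketch of that call convention, written against a plain contiguous NCHW buffer rather than PackedTensorAccessor (the function name and the raw-pointer layout are assumptions for illustration, not the ATen API):

// Sketch: (batch, channel, height, width, y, x) parameter order, with y and x
// clamped to the valid range the same way upsample_get_value_bounded does.
template <typename scalar_t>
__device__ __forceinline__ scalar_t get_value_bounded_sketch(
    const scalar_t* data, // contiguous NCHW buffer
    int channels,
    int batch,
    int channel,
    int height,
    int width,
    int y,
    int x) {
  int access_y = max(min(y, height - 1), 0);
  int access_x = max(min(x, width - 1), 0);
  return data[((batch * channels + channel) * height + access_y) * width +
              access_x];
}

// A call site then reads left to right in dimension order, e.g.
//   get_value_bounded_sketch<float>(
//       idata, channels, n, c, input_height, input_width, in_y - 1 + k, in_x - 1);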
diff --git a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu
index ee632aa..d3b02a0 100644
--- a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu
@@ -12,9 +12,7 @@
namespace {
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_bicubic2d_out_frame(
const int num_elements,
const accscalar_t height_scale,
@@ -65,18 +63,15 @@
accscalar_t coefficients[4];
for (int k = 0; k < 4; k++) {
- /* TODO: change width and height order in the arguments */
- /* TODO: maybe change x and y order in the arguments */
- /* TODO: maybe change c and n order in the arguments */
coefficients[k] = cubic_interp1d(
upsample_get_value_bounded<scalar_t>(
- idata, c, n, input_width, input_height, in_x - 1, in_y - 1 + k),
+ idata, n, c, input_height, input_width, in_y - 1 + k, in_x - 1),
upsample_get_value_bounded<scalar_t>(
- idata, c, n, input_width, input_height, in_x + 0, in_y - 1 + k),
+ idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 0),
upsample_get_value_bounded<scalar_t>(
- idata, c, n, input_width, input_height, in_x + 1, in_y - 1 + k),
+ idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 1),
upsample_get_value_bounded<scalar_t>(
- idata, c, n, input_width, input_height, in_x + 2, in_y - 1 + k),
+ idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 2),
t_x);
}
@@ -92,9 +87,7 @@
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_bicubic2d_backward_out_frame(
const int num_elements,
const accscalar_t height_scale,
@@ -102,7 +95,6 @@
const bool align_corners,
PackedTensorAccessor<scalar_t, 4> idata,
const PackedTensorAccessor<scalar_t, 4> odata) {
-
int index = threadIdx.x + blockIdx.x * blockDim.x;
const int batchsize = idata.size(0);
@@ -150,16 +142,14 @@
scalar_t out_value = odata[n][c][output_y][output_x];
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
- /* TODO: change width and height order in the arguments */
- /* TODO: maybe change x and y order in the arguments */
upsample_increment_value_bounded<scalar_t, accscalar_t>(
idata,
- c,
n,
- input_width,
+ c,
input_height,
- input_x - 1 + j,
+ input_width,
input_y - 1 + i,
+ input_x - 1 + j,
out_value * y_coeffs[i] * x_coeffs[j]);
}
}
@@ -206,8 +196,8 @@
output_width > 0);
const int num_output_elements = output_height * output_width;
- const int max_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock / 2;
+ const int max_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
// Launch kernel
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -287,8 +277,8 @@
grad_input.zero_();
const int num_kernels = output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock / 2;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
index 0931ef6..0bb523a 100644
--- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
@@ -14,9 +14,7 @@
namespace {
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_bilinear2d_out_frame(
const int n,
const accscalar_t rheight,
@@ -79,9 +77,7 @@
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_bilinear2d_backward_out_frame(
const int n,
const accscalar_t rheight,
@@ -187,8 +183,8 @@
output_width > 0);
const int num_kernels = output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -211,7 +207,7 @@
num_kernels, rheight, rwidth, align_corners, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
static void upsample_bilinear2d_backward_out_cuda_template(
@@ -260,8 +256,8 @@
grad_input.zero_();
const int num_kernels = output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -284,7 +280,7 @@
num_kernels, rheight, rwidth, align_corners, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace
diff --git a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu
index dc96028..831c37d 100644
--- a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu
@@ -12,9 +12,7 @@
namespace {
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_nearest1d_out_frame(
const int n,
const PackedTensorAccessor<scalar_t, 3> idata,
@@ -55,9 +53,7 @@
// Backward operation
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_nearest1d_backward_out_frame(
const int n,
PackedTensorAccessor<scalar_t, 3> idata,
@@ -123,8 +119,8 @@
output.zero_();
const int num_kernels = output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -141,7 +137,7 @@
stream>>>(num_kernels, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
static void upsample_nearest1d_backward_out_cuda_template(
@@ -179,8 +175,8 @@
grad_input.zero_();
const int num_kernels = output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -197,7 +193,7 @@
stream>>>(num_kernels, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace
diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu
index 6c6950c..f09c3c7 100644
--- a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu
@@ -12,9 +12,7 @@
namespace {
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_nearest2d_out_frame(
const int n,
const PackedTensorAccessor<scalar_t, 4> idata,
@@ -64,9 +62,7 @@
// Backward operation
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_nearest2d_backward_out_frame(
const int n,
PackedTensorAccessor<scalar_t, 4> idata,
@@ -153,8 +149,8 @@
output.zero_();
const int num_kernels = output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -171,7 +167,7 @@
stream>>>(num_kernels, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
static void upsample_nearest2d_backward_out_cuda_template(
@@ -219,8 +215,8 @@
grad_input.zero_();
const int num_kernels = output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -237,7 +233,7 @@
stream>>>(num_kernels, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace
diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu
index f0324ca..26cc987 100644
--- a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu
@@ -12,9 +12,7 @@
namespace {
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_nearest3d_out_frame(
const int n,
const PackedTensorAccessor<scalar_t, 5> idata,
@@ -71,9 +69,7 @@
// Backward operation
template <typename scalar_t, typename accscalar_t>
-#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(1024)
-#endif
__global__ void upsample_nearest3d_backward_out_frame(
const int n,
PackedTensorAccessor<scalar_t, 5> idata,
@@ -174,8 +170,8 @@
output.zero_();
const int num_kernels = output_depth * output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -192,7 +188,7 @@
stream>>>(num_kernels, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
static void upsample_nearest3d_backward_out_cuda_template(
@@ -244,8 +240,8 @@
grad_input.zero_();
const int num_kernels = output_depth * output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -262,7 +258,7 @@
stream>>>(num_kernels, idata, odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace
diff --git a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu
index 5e0b003..2b94c85 100644
--- a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu
@@ -237,8 +237,8 @@
output_depth > 0 && output_height > 0 && output_width > 0);
const int num_kernels = output_depth * output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -269,7 +269,7 @@
odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
static void upsample_trilinear3d_backward_out_cuda_template(
@@ -322,8 +322,8 @@
grad_input.zero_();
const int num_kernels = output_depth * output_height * output_width;
- const int num_threads =
- at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+ const int num_threads = std::min(
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -356,7 +356,7 @@
odata);
});
- AT_CUDA_CHECK(cudaGetLastError());
+ AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace