[ROCm] Enable the InTopK op on ROCm.

Generalize the CUDA-only preprocessor guards (GOOGLE_CUDA ->
GOOGLE_CUDA || TENSORFLOW_USE_ROCM) and switch the GPU kernel to the
platform-neutral launch helpers (GPU_1D_KERNEL_LOOP, GpuLaunchConfig,
GpuLaunchKernel) so it builds under both CUDA and HIP.
diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc
index a8ee00e..22d8333 100644
--- a/tensorflow/core/kernels/in_topk_op.cc
+++ b/tensorflow/core/kernels/in_topk_op.cc
@@ -116,7 +116,7 @@
.TypeConstraint<int64>("T"),
InTopK<CPUDevice, float, int64>);
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
@@ -142,6 +142,6 @@
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
InTopK<GPUDevice, float, int64>);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/in_topk_op.h b/tensorflow/core/kernels/in_topk_op.h
index 52716f2..f48932c 100644
--- a/tensorflow/core/kernels/in_topk_op.h
+++ b/tensorflow/core/kernels/in_topk_op.h
@@ -16,9 +16,9 @@
#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
diff --git a/tensorflow/core/kernels/in_topk_op_gpu.cu.cc b/tensorflow/core/kernels/in_topk_op_gpu.cu.cc
index 1894ded..4c59e1f 100644
--- a/tensorflow/core/kernels/in_topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/in_topk_op_gpu.cu.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@@ -41,7 +41,7 @@
const TargetT* targets, // dims: [ num_targets ]
int64* mask, // dims: [ num_targets x num_classes ]
int num_targets, int num_classes) {
- CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) {
+ GPU_1D_KERNEL_LOOP(i, num_targets * num_classes) {
const int batch_index = i / num_classes;
TargetT target_idx = ldg(targets + batch_index);
@@ -118,7 +118,7 @@
const auto& d = context->eigen_device<GPUDevice>();
// Compute a mask for all predictions.
- CudaLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
+ GpuLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
OP_REQUIRES_OK(
context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
config.block_count, config.thread_per_block, 0,
@@ -173,4 +173,4 @@
} // end namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM