Small fixes for hipification (#31200)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31200
We do not hipify these files when doing out of place.
Test Plan: wait for CI to clear.
Differential Revision: D18963683
fbshipit-source-id: eeba8597143f26417d0a8181a4c746139afefa24
diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h
index 62e7cb1..865d295 100644
--- a/aten/src/ATen/native/Distributions.h
+++ b/aten/src/ATen/native/Distributions.h
@@ -4,7 +4,7 @@
#include <c10/macros/Macros.h>
// ROCM hcc doesn't work well with using std:: in kernel functions
-#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
+#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_exp c10::cuda::compat::exp
#define compat_floor c10::cuda::compat::floor
@@ -12,6 +12,14 @@
#define compat_pow c10::cuda::compat::pow
#define compat_sqrt c10::cuda::compat::sqrt
#define compat_tan c10::cuda::compat::tan
+#elif defined(__HIPCC__)
+#include <c10/hip/HIPMathCompat.h>
+#define compat_exp c10::hip::compat::exp
+#define compat_floor c10::hip::compat::floor
+#define compat_log c10::hip::compat::log
+#define compat_pow c10::hip::compat::pow
+#define compat_sqrt c10::hip::compat::sqrt
+#define compat_tan c10::hip::compat::tan
#else
#define compat_exp std::exp
#define compat_floor std::floor
diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h
index 1cca15c..ab66d20 100644
--- a/aten/src/ATen/native/SharedReduceOps.h
+++ b/aten/src/ATen/native/SharedReduceOps.h
@@ -28,9 +28,12 @@
#endif
// ROCM hcc doesn't work well with using std:: in kernel functions
-#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
+#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_pow c10::cuda::compat::pow
+#elif defined(__HIPCC__)
+#include <c10/hip/HIPMathCompat.h>
+#define compat_pow c10::hip::compat::pow
#else
#define compat_pow std::pow
#endif