[rocm][cmake] retrieve rocm location from ROCM_SOURCE_DIR env if specified (#120898)
This PR allows us to build PyTorch with a rocm that is not installed
to the default location, i.e. /opt/rocm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120898
Approved by: https://github.com/jianyuh
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 1758210..3a21b3a 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1325,11 +1325,19 @@
USE_ROCM
__HIP_PLATFORM_AMD__
)
+
+ if(NOT ROCM_SOURCE_DIR)
+ set(ROCM_SOURCE_DIR "$ENV{ROCM_SOURCE_DIR}")
+ endif()
+ if($ROCM_SOURCE_DIR STREQUAL "")
+ set(ROCM_SOURCE_DIR "/opt/rocm")
+ endif()
+ message(INFO "caffe2 ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}")
target_include_directories(torch_hip PRIVATE
- /opt/rocm/include
- /opt/rocm/hcc/include
- /opt/rocm/rocblas/include
- /opt/rocm/hipsparse/include
+ ${ROCM_SOURCE_DIR}/include
+ ${ROCM_SOURCE_DIR}/hcc/include
+ ${ROCM_SOURCE_DIR}/rocblas/include
+ ${ROCM_SOURCE_DIR}/hipsparse/include
)
if(USE_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)