# blob: a14b2e1b0e8b444f26f2c7e332c5201d4acbfe9a [file] [log] [blame]
# FindTorch
# ---------
#
# Finds the Torch library
#
# This will define the following variables:
#
# TORCH_FOUND -- True if the system has the Torch library
# TORCH_INCLUDE_DIRS -- The include directories for torch
# TORCH_LIBRARIES -- Libraries to link to
#
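# and the following imported targets:
#
# Torch  (NOTE(review): no imported target is created by the code in this
#         file as shown -- confirm before relying on it; use TORCH_LIBRARIES
#         or torch_add_custom_op_library() instead)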
#
# and the following functions:
#
# torch_add_custom_op_library(<name> <source_files>)
# Root of the Torch tree: the parent of the directory containing this file.
set(TORCH_ROOT "${CMAKE_CURRENT_LIST_DIR}/../")

# Header search paths handed to consumers through TORCH_INCLUDE_DIRS.
set(TORCH_INCLUDE_DIRS
  "${TORCH_ROOT}"
  "${TORCH_ROOT}/aten/src"
  "${CMAKE_CURRENT_LIST_DIR}/aten/src"
  "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src"
  "${CMAKE_CURRENT_LIST_DIR}/caffe2/aten/src/TH"
)

# Search only the lib/ directory next to this file; NO_DEFAULT_PATH disables
# the system search locations.
find_library(TORCH_LIBRARY torch PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
find_library(CAFFE2_LIBRARY caffe2 PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)

# Names of the variables that must be resolved for Torch to count as found.
set(torch_required_vars TORCH_LIBRARY CAFFE2_LIBRARY)

# @USE_CUDA@ is a template placeholder; presumably it is substituted with a
# boolean when this file is configured -- TODO confirm against the generator.
if(@USE_CUDA@)
  find_package(CUDA REQUIRED)
  find_library(CAFFE2_CUDA_LIBRARY caffe2_gpu PATHS "${CMAKE_CURRENT_LIST_DIR}/lib" NO_DEFAULT_PATH)
  list(APPEND torch_required_vars CAFFE2_CUDA_LIBRARY)
  set(TORCH_CUDA_LIBRARIES -L${CUDA_TOOLKIT_ROOT_DIR}/lib64 cuda nvrtc cudart nvToolsExt)
  list(APPEND TORCH_INCLUDE_DIRS ${CUDA_TOOLKIT_INCLUDE})
endif()

# Everything a consumer has to link against. The CUDA entries expand to
# nothing when the CUDA branch above was not taken.
set(TORCH_LIBRARIES
  ${TORCH_LIBRARY}
  ${CAFFE2_LIBRARY}
  ${CAFFE2_CUDA_LIBRARY}
  ${TORCH_CUDA_LIBRARIES}
)

# Define TORCH_FOUND (documented in the header of this file) and report a
# clear message -- or fail when the package was REQUIRED -- if any of the
# libraries above could not be located, instead of leaving a *-NOTFOUND entry
# in TORCH_LIBRARIES to be discovered later.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Torch DEFAULT_MSG ${torch_required_vars})
unset(torch_required_vars)
# Creates a shared library <name> with the correct include directories
# and linker flags set to include Torch header files and link with Torch
# libraries. Also sets the C++ standard version to C++11. All options
# can be overridden by specifying further options on the `<name>` CMake target.
#
# Arguments:
#   name          - name of the shared library target to create
#   source_files  - source file(s) of the library; may be given as one
#                   (quoted) list or as several separate arguments
#
# Example:
#   torch_add_custom_op_library(my_ops op.cpp helper.cpp)
function(torch_add_custom_op_library name source_files)
  # ${ARGN} collects sources passed as additional separate arguments; without
  # it only the first source file would be compiled and the rest dropped.
  add_library(${name} SHARED ${source_files} ${ARGN})
  target_include_directories(${name} PUBLIC "${TORCH_INCLUDE_DIRS}")
  # The keyword-less signature is kept on purpose: CMake does not allow mixing
  # the plain and keyword signatures on one target, so switching here could
  # break callers that later add libraries with the plain form.
  target_link_libraries(${name} "${TORCH_LIBRARIES}")
  target_compile_options(${name} PUBLIC -std=c++11)
endfunction()