Skip to content

Commit

Permalink
Update FindCUDNN.cmake for cuDNN 9 (#640)
Browse files Browse the repository at this point in the history
* update cudnn cmake for v9

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back license information

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
  • Loading branch information
cyanguwa authored Jan 31, 2024
1 parent 70bd26e commit e2803b1
Showing 1 changed file with 57 additions and 25 deletions.
82 changes: 57 additions & 25 deletions transformer_engine/cmake/FindCUDNN.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,29 @@ find_path(
CUDNN_INCLUDE_DIR cudnn.h
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
PATH_SUFFIXES include
REQUIRED
)

function(find_cudnn_library NAME)
string(TOUPPER ${NAME} UPPERCASE_NAME)
file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}")
string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")

function(find_cudnn_library NAME)
find_library(
${UPPERCASE_NAME}_LIBRARY ${NAME}
${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
PATH_SUFFIXES lib64 lib/x64 lib
REQUIRED
)

if(${UPPERCASE_NAME}_LIBRARY)
if(${NAME}_LIBRARY)
add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
set_target_properties(
CUDNN::${NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
IMPORTED_LOCATION ${${UPPERCASE_NAME}_LIBRARY}
IMPORTED_LOCATION ${${NAME}_LIBRARY}
)
message(STATUS "${NAME} found at ${${UPPERCASE_NAME}_LIBRARY}.")
message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.")
else()
message(STATUS "${NAME} not found.")
endif()
Expand All @@ -35,24 +39,18 @@ function(find_cudnn_library NAME)
endfunction()

find_cudnn_library(cudnn)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)

include (FindPackageHandleStandardArgs)
find_package_handle_standard_args(
CUDNN REQUIRED_VARS
CUDNN_INCLUDE_DIR CUDNN_LIBRARY
LIBRARY REQUIRED_VARS
CUDNN_INCLUDE_DIR cudnn_LIBRARY
)

if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)

message(STATUS "cuDNN: ${CUDNN_LIBRARY}")
message(STATUS "cuDNN: ${cudnn_LIBRARY}")
message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")

set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")

else()
Expand All @@ -71,11 +69,45 @@ target_include_directories(
target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv_train
CUDNN::cudnn_ops_train
CUDNN::cudnn_cnn_train
CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer
CUDNN::cudnn
CUDNN::cudnn
)

if(CUDNN_MAJOR_VERSION EQUAL 8)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)

target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv_train
CUDNN::cudnn_ops_train
CUDNN::cudnn_cnn_train
CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer
)
elseif(CUDNN_MAJOR_VERSION EQUAL 9)
find_cudnn_library(cudnn_cnn)
find_cudnn_library(cudnn_adv)
find_cudnn_library(cudnn_graph)
find_cudnn_library(cudnn_ops)
find_cudnn_library(cudnn_engines_runtime_compiled)
find_cudnn_library(cudnn_engines_precompiled)
find_cudnn_library(cudnn_heuristic)

target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv
CUDNN::cudnn_ops
CUDNN::cudnn_cnn
CUDNN::cudnn_graph
CUDNN::cudnn_engines_runtime_compiled
CUDNN::cudnn_engines_precompiled
CUDNN::cudnn_heuristic
)
endif()

0 comments on commit e2803b1

Please sign in to comment.