Skip to content

Commit 553037f

Browse files
add rocm support:
- resolve nvbench problem - add hip cuda defs and port test_norm - add test_norm & bench_norm
1 parent f2ca781 commit 553037f

18 files changed

+709
-20
lines changed

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
[submodule "3rdparty/nvbench"]
22
path = 3rdparty/nvbench
33
url = https://github.com/NVIDIA/nvbench.git
4+
[submodule "3rdparty/hipbench"]
5+
path = 3rdparty/hipbench
6+
url = https://github.com/ROCm/hipBench.git
47
[submodule "3rdparty/googletest"]
58
path = 3rdparty/googletest
69
url = https://github.com/google/googletest.git

CMakeLists.txt

+159-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,72 @@
1-
cmake_minimum_required(VERSION 3.23.1)
2-
project(flashinfer CUDA CXX)
1+
cmake_minimum_required(VERSION 3.26.4)
2+
3+
# set compiler conditional
4+
# Verified for ROCM >= 6.2, alias to $hip_LIB_INSTALL_DIR defined in ${ROCM_HOME}/lib/cmake/hip/hip-config-amd.cmake
5+
set(ROCM_HOME "/opt/rocm" CACHE PATH "ROCM SDK INSTALLATION HOME")
6+
if (NOT IS_DIRECTORY ${ROCM_HOME})
7+
message(WARNING "ROCM_HOME ${ROCM_HOME} is not a directory")
8+
endif()
9+
10+
if (LINUX)
11+
# SDK Root in CMAKE config file; LINUX system defaults to ENV{ROCM_PATH}; WIN32 system defaults to ENV{HIP_PATH}
12+
set(ENV{ROCM_PATH} ${ROCM_HOME})
13+
endif()
14+
15+
if(NOT DEFINED HIP_CMAKE_PATH)
16+
if(NOT DEFINED ENV{HIP_CMAKE_PATH})
17+
# NOTE(yiakwy) : find_package(HIP) will first search for
18+
# cmake/Modules/FindAMDDeviceLibs.cmake
19+
# , then
20+
# /opt/rocm/lib/cmake/AMDDeviceLibs/AMDDeviceLibsConfig.cmake
21+
# this will add hip::host, hip::device dependencies to be linked by any hip targets (ROCM >= 6.x).
22+
# Add hip-config.cmake to CMake module search path.
23+
# set(HIP_CMAKE_PATH "${ROCM_HOME}/share/rocm/cmake" "${ROCM_HOME}/share/rocmcmakebuildtools/cmake/" CACHE PATH "Path to which HIP has been installed")
24+
# NOTE(yiakwy) : adding ${ROCM_HOME}/lib/cmake/hip has conflicts with 3rdparty/mscclpp
25+
set(HIP_CMAKE_PATH "${ROCM_HOME}/lib/cmake/AMDDeviceLibs" "${ROCM_HOME}/lib/cmake/amd_comgr" "${ROCM_HOME}/lib/cmake/hsa-runtime64" CACHE PATH "Path to which HIP has been installed")
26+
message(WARNING "System variable HIP_CMAKE_PATH is nonexist, defaults to ${HIP_CMAKE_PATH}")
27+
28+
set(CMAKE_PREFIX_PATH "${ROCM_HOME};${CMAKE_PREFIX_PATH}")
29+
else()
30+
set(HIP_CMAKE_PATH $ENV{HIP_CMAKE_PATH} CACHE PATH "Path to which HIP has been installed")
31+
endif()
32+
endif()
33+
34+
set(CMAKE_MODULE_PATH "${HIP_CMAKE_PATH}" ${CMAKE_MODULE_PATH})
35+
36+
##### Flash infer project
37+
project(flashinfer C CXX)
38+
39+
# set CMAKE_CXX_COMPILER to hipcc
40+
# set(CMAKE_FIND_DEBUG_MODE TRUE)
41+
find_package(HIP QUIET)
42+
if(HIP_FOUND)
43+
message(STATUS "Found HIP: " ${HIP_VERSION})
44+
execute_process(COMMAND bash -c "${ROCM_HOME}/bin/rocminfo | grep -o -m1 'gfx.*'"  # probe the first gfx target via rocminfo from ROCM_HOME, not a hard-coded /opt/rocm
45+
OUTPUT_VARIABLE CMAKE_HIP_ARCHITECTURES OUTPUT_STRIP_TRAILING_WHITESPACE)
46+
47+
enable_language(HIP)
48+
49+
add_definitions(-DUSE_ROCM)
50+
else()
51+
message(WARNING "Could not find HIP. Ensure that ROCM SDK is either installed in /opt/rocm or the variable HIP_CMAKE_PATH is set to point to the right location.")
52+
endif()
53+
54+
find_package(CUDA QUIET)
55+
if (CUDA_FOUND)
56+
message(STATUS "FOUND CUDA: " ${CUDA_TOOLKIT_ROOT_DIR})
57+
else()
58+
message(WARNING "Could not find CUDA.")
59+
endif()
60+
61+
if (NOT (HIP_FOUND) AND NOT (CUDA_FOUND))  # neither GPU SDK was located above
62+
message(FATAL_ERROR "ROCM/CUDA SDK must be supported")  # FATAL_ERROR (not FATAL) is the mode that stops configuration
63+
endif()
364

465
include(cmake/utils/Utils.cmake)
566

667
set(CMAKE_CXX_STANDARD 17)
768
set(CMAKE_CUDA_STANDARD 17)
69+
set(CMAKE_HIP_STANDARD 17)
870

971
if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake)
1072
include(${CMAKE_BINARY_DIR}/config.cmake)
@@ -45,23 +107,41 @@ flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0
45107
flashinfer_option(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "QK reductions to enable" "false" "true")
46108
flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2)
47109

110+
# ROCM ARCH
111+
if(DEFINED CMAKE_HIP_ARCHITECTURES)
112+
message(STATUS "CMAKE_HIP_ARCHITECTURES : ${CMAKE_HIP_ARCHITECTURES}")
113+
114+
else(CMAKE_HIP_ARCHITECTURES)
115+
116+
# CUDA ARCH
48117
if(DEFINED FLASHINFER_CUDA_ARCHITECTURES)
49-
message(STATUS "CMAKE_CUDA_ARCHITECTURES set to ${FLASHINFER_CUDA_ARCHITECTURES}.")
118+
message(STATUS "CMAKE_CUDA_ARCHITECTURES set to
119+
${FLASHINFER_CUDA_ARCHITECTURES}.")
50120
set(CMAKE_CUDA_ARCHITECTURES ${FLASHINFER_CUDA_ARCHITECTURES})
51121
else(DEFINED FLASHINFER_CUDA_ARCHITECTURES)
52122
message(STATUS "CMAKE_CUDA_ARCHITECTURES is ${CMAKE_CUDA_ARCHITECTURES}")
53123
endif(DEFINED FLASHINFER_CUDA_ARCHITECTURES)
54124

125+
endif()
126+
55127
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
128+
list(APPEND CMAKE_MODULE_PATH "${ROCM_HOME}/lib/cmake/hip")
56129
if(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM)
57130
message(STATUS "NVBench and GoogleTest enabled")
58-
add_subdirectory(3rdparty/nvbench)
59-
if(FLASHINFER_DISTRIBUTED)
131+
if (HIP_FOUND)
132+
add_subdirectory(3rdparty/hipbench)
133+
else()
134+
add_subdirectory(3rdparty/nvbench)
135+
endif()
136+
if (FLASHINFER_DISTRIBUTED)
137+
message(STATUS "compiling 3rdparty/mscclpp ...")
60138
add_subdirectory(3rdparty/mscclpp)
61139
else(FLASHINFER_DISTRIBUTED)
62140
add_subdirectory(3rdparty/googletest)
63141
endif(FLASHINFER_DISTRIBUTED)
64142
endif(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM)
143+
144+
# fixed with rocm path
65145
find_package(Thrust REQUIRED)
66146

67147
set(
@@ -77,6 +157,8 @@ endif(FLASHINFER_ENABLE_FP8)
77157
if(FLASHINFER_ENABLE_BF16)
78158
message(STATUS "Compile bf16 kernels.")
79159
add_definitions(-DFLASHINFER_ENABLE_BF16)
160+
else()
161+
message (WARNING "Since bf16 is not enabled, many tests will be disabled.")
80162
endif(FLASHINFER_ENABLE_BF16)
81163

82164
# generate kernel inst
@@ -189,6 +271,9 @@ endforeach(head_dim)
189271
add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src})
190272
target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
191273
target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all)
274+
if (HIP_FOUND)
275+
set_target_properties(decode_kernels PROPERTIES LINKER_LANGUAGE HIP)
276+
endif()
192277

193278
# single prefill kernel inst generation
194279
foreach(head_dim IN LISTS HEAD_DIMS)
@@ -302,6 +387,9 @@ endforeach(head_dim)
302387
add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src})
303388
target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
304389
target_compile_options(prefill_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all)
390+
if (HIP_FOUND)
391+
set_target_properties(prefill_kernels PROPERTIES LINKER_LANGUAGE HIP)
392+
endif()
305393

306394
if (FLASHINFER_DECODE)
307395
message(STATUS "Compile single decode kernel benchmarks.")
@@ -315,6 +403,7 @@ if (FLASHINFER_DECODE)
315403

316404
message(STATUS "Compile single decode kernel tests.")
317405
file(GLOB_RECURSE TEST_DECODE_SRCS ${PROJECT_SOURCE_DIR}/src/test_single_decode.cu)
406+
message(STATUS "test source : ${TEST_DECODE_SRCS}")
318407
add_executable(test_single_decode ${TEST_DECODE_SRCS})
319408
target_include_directories(test_single_decode PRIVATE ${FLASHINFER_INCLUDE_DIR})
320409
target_include_directories(test_single_decode PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
@@ -339,6 +428,13 @@ if (FLASHINFER_DECODE)
339428
add_dependencies(test_batch_decode dispatch_inc)
340429
target_link_libraries(test_batch_decode PRIVATE gtest gtest_main decode_kernels)
341430
target_compile_options(test_batch_decode PRIVATE -Wno-switch-bool)
431+
432+
if (HIP_FOUND)
433+
set_target_properties(bench_single_decode PROPERTIES LINKER_LANGUAGE HIP)
434+
set_target_properties(test_single_decode PROPERTIES LINKER_LANGUAGE HIP)
435+
set_target_properties(bench_batch_decode PROPERTIES LINKER_LANGUAGE HIP)
436+
set_target_properties(test_batch_decode PROPERTIES LINKER_LANGUAGE HIP)
437+
endif()
342438
endif(FLASHINFER_DECODE)
343439

344440
if (FLASHINFER_PREFILL)
@@ -377,6 +473,13 @@ if (FLASHINFER_PREFILL)
377473
add_dependencies(test_batch_prefill dispatch_inc)
378474
target_link_libraries(test_batch_prefill PRIVATE gtest gtest_main prefill_kernels)
379475
target_compile_options(test_batch_prefill PRIVATE -Wno-switch-bool)
476+
477+
if (HIP_FOUND)
478+
set_target_properties(bench_single_prefill PROPERTIES LINKER_LANGUAGE HIP)
479+
set_target_properties(test_single_prefill PROPERTIES LINKER_LANGUAGE HIP)
480+
set_target_properties(bench_batch_prefill PROPERTIES LINKER_LANGUAGE HIP)
481+
set_target_properties(test_batch_prefill PROPERTIES LINKER_LANGUAGE HIP)
482+
endif()
380483
endif(FLASHINFER_PREFILL)
381484

382485
if (FLASHINFER_PAGE)
@@ -387,6 +490,10 @@ if (FLASHINFER_PAGE)
387490
target_include_directories(test_page PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
388491
target_link_libraries(test_page PRIVATE gtest gtest_main)
389492
target_compile_options(test_page PRIVATE -Wno-switch-bool)
493+
494+
if (HIP_FOUND)
495+
set_target_properties(test_page PROPERTIES LINKER_LANGUAGE HIP)
496+
endif()
390497
endif(FLASHINFER_PAGE)
391498

392499
if (FLASHINFER_CASCADE)
@@ -407,6 +514,10 @@ if (FLASHINFER_CASCADE)
407514
add_dependencies(test_cascade dispatch_inc)
408515
target_link_libraries(test_cascade PRIVATE gtest gtest_main decode_kernels prefill_kernels)
409516
target_compile_options(test_cascade PRIVATE -Wno-switch-bool)
517+
518+
if (HIP_FOUND)
519+
set_target_properties(test_cascade PROPERTIES LINKER_LANGUAGE HIP)
520+
endif()
410521
endif(FLASHINFER_CASCADE)
411522

412523
if (FLASHINFER_SAMPLING)
@@ -425,27 +536,52 @@ if (FLASHINFER_SAMPLING)
425536
target_include_directories(test_sampling PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
426537
target_link_libraries(test_sampling PRIVATE gtest gtest_main)
427538
target_compile_options(test_sampling PRIVATE -Wno-switch-bool)
539+
540+
if (HIP_FOUND)
541+
set_target_properties(bench_sampling PROPERTIES LINKER_LANGUAGE HIP)
542+
set_target_properties(test_sampling PROPERTIES LINKER_LANGUAGE HIP)
543+
endif()
428544
endif(FLASHINFER_SAMPLING)
429545

430-
if (FLASHINFER_NORM)
546+
if (TRUE)#(FLASHINFER_NORM) TODO(yiakwy) : fix options
431547
message(STATUS "Compile normalization kernel benchmarks.")
432548
file(GLOB_RECURSE BENCH_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/bench_norm.cu)
433-
add_executable(bench_norm ${BENCH_NORM_SRCS})
549+
550+
if (HIP_FOUND)
551+
set_source_files_properties(${BENCH_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
552+
hip_add_executable(bench_norm ${BENCH_NORM_SRCS})
553+
target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/hipbench)
554+
else(HIP_FOUND)
555+
add_executable(bench_norm ${BENCH_NORM_SRCS})
556+
target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench)
557+
endif()
558+
434559
target_include_directories(bench_norm PRIVATE ${FLASHINFER_INCLUDE_DIR})
435-
target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench)
436560
target_link_libraries(bench_norm PRIVATE nvbench::main)
437561
target_compile_options(bench_norm PRIVATE -Wno-switch-bool)
438562

439563
message(STATUS "Compile normalization kernel tests.")
440564
file(GLOB_RECURSE TEST_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/test_norm.cu)
441-
add_executable(test_norm ${TEST_NORM_SRCS})
565+
566+
if (HIP_FOUND)
567+
set_source_files_properties(${TEST_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
568+
hip_add_executable(test_norm ${TEST_NORM_SRCS})
569+
else(HIP_FOUND)
570+
add_executable(test_norm ${TEST_NORM_SRCS})
571+
endif()
572+
442573
target_include_directories(test_norm PRIVATE ${FLASHINFER_INCLUDE_DIR})
443574
target_include_directories(test_norm PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
444575
target_link_libraries(test_norm PRIVATE gtest gtest_main)
445576
target_compile_options(test_norm PRIVATE -Wno-switch-bool)
577+
578+
if (HIP_FOUND)
579+
set_target_properties(bench_norm PROPERTIES LINKER_LANGUAGE HIP)
580+
set_target_properties(test_norm PROPERTIES LINKER_LANGUAGE HIP)
581+
endif()
446582
endif(FLASHINFER_NORM)
447583

448-
if(FLASHINFER_TVM_BINDING)
584+
if (FLASHINFER_TVM_BINDING)
449585
message(STATUS "Compile tvm binding.")
450586
if(NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "")
451587
set(TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR})
@@ -477,6 +613,10 @@ if(FLASHINFER_FASTDIV_TEST)
477613
target_include_directories(test_fastdiv PRIVATE ${FLASHINFER_INCLUDE_DIR})
478614
target_include_directories(test_fastdiv PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
479615
target_link_libraries(test_fastdiv PRIVATE gtest gtest_main)
616+
617+
if (HIP_FOUND)
618+
set_target_properties(test_fastdiv PROPERTIES LINKER_LANGUAGE HIP)
619+
endif()
480620
endif(FLASHINFER_FASTDIV_TEST)
481621

482622
if(FLASHINFER_FASTDEQUANT_TEST)
@@ -486,9 +626,11 @@ if(FLASHINFER_FASTDEQUANT_TEST)
486626
target_include_directories(test_fast_dequant PRIVATE ${FLASHINFER_INCLUDE_DIR})
487627
target_include_directories(test_fast_dequant PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
488628
target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main)
489-
endif(FLASHINFER_FASTDEQUANT_TEST)
490-
491629

630+
if (HIP_FOUND)
631+
set_target_properties(test_fast_dequant PROPERTIES LINKER_LANGUAGE HIP)
632+
endif()
633+
endif(FLASHINFER_FASTDEQUANT_TEST)
492634

493635
if (FLASHINFER_DISTRIBUTED)
494636
find_package(MPI REQUIRED)
@@ -506,4 +648,9 @@ if (FLASHINFER_DISTRIBUTED)
506648
target_include_directories(test_attn_all_reduce PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include 3rdparty/spdlog/include)
507649
target_link_libraries(test_attn_all_reduce PRIVATE MPI::MPI_CXX mscclpp)
508650
target_compile_definitions(test_attn_all_reduce PRIVATE -DENABLE_MPI)
651+
652+
if (HIP_FOUND)
653+
set_target_properties(test_sum_all_reduce PROPERTIES LINKER_LANGUAGE HIP)
654+
set_target_properties(test_attn_all_reduce PROPERTIES LINKER_LANGUAGE HIP)
655+
endif()
509656
endif(FLASHINFER_DISTRIBUTED)

cmake/config.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,4 @@ set(FLASHINFER_GEN_MASK_MODES 0 1 2)
4040
# So it's recommended to set it to a specific value if you know the architecture of the target GPU.
4141
# Example:
4242
# set(FLASHINFER_CUDA_ARCHITECTURES 80)
43-
set(FLASHINFER_CUDA_ARCHITECTURES native)
43+
set(FLASHINFER_CUDA_ARCHITECTURES native)

cmake/modules/FindThrust.cmake

+2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ find_path( THRUST_INCLUDE_DIR
3333
/usr/include/cuda
3434
/usr/local/include
3535
/usr/local/cuda/include
36+
/opt/rocm/include
3637
${CUDA_INCLUDE_DIRS}
38+
${HIP_INCLUDE_DIRS}
3739
NAMES thrust/version.h
3840
DOC "Thrust headers"
3941
)

cmake/utils/Utils.cmake

+4
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,18 @@ macro(flashinfer_option variable description value)
3636
if("${__value}" MATCHES ";")
3737
# list values directly pass through
3838
__flashinfer_option(${variable} "${description}" "${__value}")
39+
message(STATUS "1 : creating ${variable} option, description : ${description}, value : ${__value}")
3940
elseif(DEFINED ${__value})
4041
if(${__value})
4142
__flashinfer_option(${variable} "${description}" ON)
43+
message(STATUS "2 : creating ${variable} option, description : ${description}, value : ON")
4244
else()
4345
__flashinfer_option(${variable} "${description}" OFF)
46+
message(STATUS "3 : creating ${variable} option, description : ${description}, value : OFF")
4447
endif()
4548
else()
4649
__flashinfer_option(${variable} "${description}" "${__value}")
50+
message(STATUS "4 : creating ${variable} option, description : ${description}, value : ${__value}")
4751
endif()
4852
else()
4953
unset(${variable} CACHE)

0 commit comments

Comments
 (0)