1
- cmake_minimum_required (VERSION 3.23.1)
2
- project (flashinfer CUDA CXX)
cmake_minimum_required(VERSION 3.26.4)

# Root of the ROCm SDK installation. Stored in the cache so it can be
# overridden on the command line with -DROCM_HOME=<path>.
# Verified for ROCM >= 6.2, alias to $hip_LIB_INSTALL_DIR defined in
# ${ROCM_HOME}/lib/cmake/hip/hip-config-amd.cmake
set(ROCM_HOME "/opt/rocm" CACHE PATH "ROCM SDK INSTALLATION HOME")

# Only warn: a missing SDK is handled later, once both HIP and CUDA have been
# probed. The expansion is quoted so an empty ROCM_HOME is still one argument.
if(NOT IS_DIRECTORY "${ROCM_HOME}")
  message(WARNING "ROCM_HOME ${ROCM_HOME} is not a directory")
endif()

if(LINUX)
  # SDK Root in CMAKE config file; LINUX system defaults to ENV{ROCM_PATH};
  # WIN32 system defaults to ENV{HIP_PATH}
  set(ENV{ROCM_PATH} "${ROCM_HOME}")
endif()
14
+
# Resolve HIP_CMAKE_PATH, the directories that are put in front of
# CMAKE_MODULE_PATH below. Precedence:
#   1. a HIP_CMAKE_PATH variable that is already defined (e.g. -DHIP_CMAKE_PATH=...)
#   2. the HIP_CMAKE_PATH environment variable
#   3. defaults derived from ROCM_HOME
if(NOT DEFINED HIP_CMAKE_PATH)
  if(NOT DEFINED ENV{HIP_CMAKE_PATH})
    # NOTE(yiakwy) : find_package(HIP) will first search for
    #   cmake/Modules/FindAMDDeviceLibs.cmake
    # , then
    #   /opt/rocm/lib/cmake/AMDDeviceLibs/AMDDeviceLibsConfig.cmake
    # this will add hip::host, hip::device dependencies to be linked by any hip
    # targets (ROCM >= 6.x).
    # NOTE(yiakwy) : adding ${ROCM_HOME}/lib/cmake/hip has conflicts with
    # 3rdparty/mscclpp
    set(HIP_CMAKE_PATH
        "${ROCM_HOME}/lib/cmake/AMDDeviceLibs"
        "${ROCM_HOME}/lib/cmake/amd_comgr"
        "${ROCM_HOME}/lib/cmake/hsa-runtime64"
        CACHE PATH "Path to which HIP has been installed")
    message(WARNING "System variable HIP_CMAKE_PATH is nonexist, defaults to ${HIP_CMAKE_PATH}")

    # Put the ROCm prefix in front of any prefixes that are already set.
    # list(PREPEND) avoids the empty trailing element that string
    # concatenation leaves behind when CMAKE_PREFIX_PATH starts out empty.
    list(PREPEND CMAKE_PREFIX_PATH "${ROCM_HOME}")
  else()
    # Quoted so that an empty (but defined) environment value stays a single
    # argument and cannot shift the CACHE keyword.
    set(HIP_CMAKE_PATH "$ENV{HIP_CMAKE_PATH}" CACHE PATH "Path to which HIP has been installed")
  endif()
endif()

# The HIP directories take priority over module paths configured earlier.
list(PREPEND CMAKE_MODULE_PATH ${HIP_CMAKE_PATH})
35
+
##### Flash infer project
project(flashinfer C CXX)

# Probe for the ROCm/HIP toolchain. HIP is optional at this point: when it is
# missing only a warning is printed and HIP_FOUND stays false.
find_package(HIP QUIET)
if(HIP_FOUND)
  message(STATUS "Found HIP: " ${HIP_VERSION})

  # Auto-detect the GPU architecture only when the caller has not chosen one,
  # so that -DCMAKE_HIP_ARCHITECTURES=... is respected.
  if(NOT DEFINED CMAKE_HIP_ARCHITECTURES)
    # Query rocminfo from the configured SDK root (not a hard-coded /opt/rocm)
    # and keep the first "gfx..." token it prints.
    execute_process(
      COMMAND bash -c "\"${ROCM_HOME}/bin/rocminfo\" | grep -o -m1 'gfx.*'"
      OUTPUT_VARIABLE detected_hip_arch
      OUTPUT_STRIP_TRAILING_WHITESPACE)

    # rocminfo prints nothing when it is missing or finds no GPU (e.g. in a
    # build container). Leave CMAKE_HIP_ARCHITECTURES undefined in that case
    # instead of handing an empty value to enable_language(HIP).
    if(NOT "${detected_hip_arch}" STREQUAL "")
      set(CMAKE_HIP_ARCHITECTURES "${detected_hip_arch}")
    else()
      message(WARNING "Could not detect a GPU architecture with ${ROCM_HOME}/bin/rocminfo; set CMAKE_HIP_ARCHITECTURES explicitly if the default is not suitable.")
    endif()
  endif()

  enable_language(HIP)

  # Directory-scoped on purpose: applies to every target defined after this point.
  add_definitions(-DUSE_ROCM)
else()
  message(WARNING "Could not find HIP. Ensure that ROCM SDK is either installed in /opt/rocm or the variable HIP_CMAKE_PATH is set to point to the right location.")
endif()
53
+
# Probe for the CUDA toolkit. Like HIP it is optional on its own; the guard
# below requires that at least one of the two SDKs was found.
# NOTE(review): project() lists only C and CXX and no enable_language(CUDA) is
# visible here -- confirm the CUDA language is enabled elsewhere for CUDA builds.
find_package(CUDA QUIET)
if(CUDA_FOUND)
  message(STATUS "FOUND CUDA: " ${CUDA_TOOLKIT_ROOT_DIR})
else()
  message(WARNING "Could not find CUDA.")
endif()

# Abort configuration when no GPU SDK is available. The mode must be
# FATAL_ERROR: "FATAL" is not a message() mode, so it was printed as part of
# the text and configuration carried on.
if(NOT HIP_FOUND AND NOT CUDA_FOUND)
  message(FATAL_ERROR "ROCM/CUDA SDK must be supported")
endif()
3
64
4
65
include (cmake/utils/Utils.cmake)
5
66
6
67
set (CMAKE_CXX_STANDARD 17)
7
68
set (CMAKE_CUDA_STANDARD 17)
69
+ set (CMAKE_HIP_STANDARD 17)
8
70
9
71
if (EXISTS ${CMAKE_BINARY_DIR} /config.cmake)
10
72
include (${CMAKE_BINARY_DIR} /config.cmake)
@@ -45,23 +107,41 @@ flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0
45
107
flashinfer_option(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "QK reductions to enable" "false" "true" )
46
108
flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2)
47
109
# Report the GPU architectures of this build. When CMAKE_HIP_ARCHITECTURES is
# defined the ROCm value is printed; otherwise the CUDA architectures are
# taken from FLASHINFER_CUDA_ARCHITECTURES when that is defined.
if(DEFINED CMAKE_HIP_ARCHITECTURES)
  # ROCM ARCH
  message(STATUS "CMAKE_HIP_ARCHITECTURES : ${CMAKE_HIP_ARCHITECTURES}")
else()
  # CUDA ARCH
  if(DEFINED FLASHINFER_CUDA_ARCHITECTURES)
    # Kept on one line so the status text contains no embedded newline.
    message(STATUS "CMAKE_CUDA_ARCHITECTURES set to ${FLASHINFER_CUDA_ARCHITECTURES}.")
    set(CMAKE_CUDA_ARCHITECTURES ${FLASHINFER_CUDA_ARCHITECTURES})
  else()
    message(STATUS "CMAKE_CUDA_ARCHITECTURES is ${CMAKE_CUDA_ARCHITECTURES}")
  endif()
endif()
126
+
# Extend the module search path with the project's own modules and the HIP
# package directory of the ROCm SDK.
list(APPEND CMAKE_MODULE_PATH
     "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules"
     "${ROCM_HOME}/lib/cmake/hip")

# The benchmark and test harnesses are only added when at least one
# test/benchmark group is enabled.
if(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM)
  message(STATUS "NVBench and GoogleTest enabled")

  # Benchmark harness: hipbench on ROCm builds, nvbench otherwise.
  if(HIP_FOUND)
    add_subdirectory(3rdparty/hipbench)
  else()
    add_subdirectory(3rdparty/nvbench)
  endif()

  # Distributed builds add mscclpp here; googletest is only added directly
  # for non-distributed builds.
  if(FLASHINFER_DISTRIBUTED)
    message(STATUS "compiling 3rdparty/mscclpp ...")
    add_subdirectory(3rdparty/mscclpp)
  else()
    add_subdirectory(3rdparty/googletest)
  endif()
endif()
143
+
144
+ # fixed with rocm path
65
145
find_package (Thrust REQUIRED)
66
146
67
147
set (
@@ -77,6 +157,8 @@ endif(FLASHINFER_ENABLE_FP8)
77
157
if (FLASHINFER_ENABLE_BF16)
78
158
message (STATUS "Compile bf16 kernels." )
79
159
add_definitions (-DFLASHINFER_ENABLE_BF16)
160
+ else ()
161
+ message (WARNING "Since bf16 is not enabled, many tests will be disabled." )
80
162
endif (FLASHINFER_ENABLE_BF16)
81
163
82
164
# generate kernel inst
@@ -189,6 +271,9 @@ endforeach(head_dim)
189
271
add_library (decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src} )
190
272
target_include_directories (decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR} )
191
273
target_compile_options (decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all )
# ROCm builds: select HIP as the linker language of the decode kernel library.
if(HIP_FOUND)
  set_property(TARGET decode_kernels PROPERTY LINKER_LANGUAGE HIP)
endif()
192
277
193
278
# single prefill kernel inst generation
194
279
foreach (head_dim IN LISTS HEAD_DIMS)
@@ -302,6 +387,9 @@ endforeach(head_dim)
302
387
add_library (prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src} )
303
388
target_include_directories (prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR} )
304
389
target_compile_options (prefill_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all )
# ROCm builds: select HIP as the linker language of the prefill kernel library.
if(HIP_FOUND)
  set_property(TARGET prefill_kernels PROPERTY LINKER_LANGUAGE HIP)
endif()
305
393
306
394
if (FLASHINFER_DECODE)
307
395
message (STATUS "Compile single decode kernel benchmarks." )
@@ -315,6 +403,7 @@ if (FLASHINFER_DECODE)
315
403
316
404
message (STATUS "Compile single decode kernel tests." )
317
405
file (GLOB_RECURSE TEST_DECODE_SRCS ${PROJECT_SOURCE_DIR} /src/test_single_decode.cu)
406
+ message (STATUS "test source : ${TEST_DECODE_SRCS} " )
318
407
add_executable (test_single_decode ${TEST_DECODE_SRCS} )
319
408
target_include_directories (test_single_decode PRIVATE ${FLASHINFER_INCLUDE_DIR} )
320
409
target_include_directories (test_single_decode PRIVATE ${gtest_SOURCE_DIR} /include ${gtest_SOURCE_DIR} )
@@ -339,6 +428,13 @@ if (FLASHINFER_DECODE)
339
428
add_dependencies (test_batch_decode dispatch_inc)
340
429
target_link_libraries (test_batch_decode PRIVATE gtest gtest_main decode_kernels)
341
430
target_compile_options (test_batch_decode PRIVATE -Wno-switch-bool )
431
+
# ROCm builds: select HIP as the linker language of all decode benchmarks and
# tests in a single call.
if(HIP_FOUND)
  set_target_properties(
    bench_single_decode
    test_single_decode
    bench_batch_decode
    test_batch_decode
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
342
438
endif (FLASHINFER_DECODE)
343
439
344
440
if (FLASHINFER_PREFILL)
@@ -377,6 +473,13 @@ if (FLASHINFER_PREFILL)
377
473
add_dependencies (test_batch_prefill dispatch_inc)
378
474
target_link_libraries (test_batch_prefill PRIVATE gtest gtest_main prefill_kernels)
379
475
target_compile_options (test_batch_prefill PRIVATE -Wno-switch-bool )
476
+
# ROCm builds: select HIP as the linker language of all prefill benchmarks and
# tests in a single call.
if(HIP_FOUND)
  set_target_properties(
    bench_single_prefill
    test_single_prefill
    bench_batch_prefill
    test_batch_prefill
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
380
483
endif (FLASHINFER_PREFILL)
381
484
382
485
if (FLASHINFER_PAGE)
@@ -387,6 +490,10 @@ if (FLASHINFER_PAGE)
387
490
target_include_directories (test_page PRIVATE ${gtest_SOURCE_DIR} /include ${gtest_SOURCE_DIR} )
388
491
target_link_libraries (test_page PRIVATE gtest gtest_main)
389
492
target_compile_options (test_page PRIVATE -Wno-switch-bool )
493
+
# ROCm builds: select HIP as the linker language of the page test.
if(HIP_FOUND)
  set_property(TARGET test_page PROPERTY LINKER_LANGUAGE HIP)
endif()
390
497
endif (FLASHINFER_PAGE)
391
498
392
499
if (FLASHINFER_CASCADE)
@@ -407,6 +514,10 @@ if (FLASHINFER_CASCADE)
407
514
add_dependencies (test_cascade dispatch_inc)
408
515
target_link_libraries (test_cascade PRIVATE gtest gtest_main decode_kernels prefill_kernels)
409
516
target_compile_options (test_cascade PRIVATE -Wno-switch-bool )
517
+
# ROCm builds: select HIP as the linker language of the cascade test.
if(HIP_FOUND)
  set_property(TARGET test_cascade PROPERTY LINKER_LANGUAGE HIP)
endif()
410
521
endif (FLASHINFER_CASCADE)
411
522
412
523
if (FLASHINFER_SAMPLING)
@@ -425,27 +536,52 @@ if (FLASHINFER_SAMPLING)
425
536
target_include_directories (test_sampling PRIVATE ${gtest_SOURCE_DIR} /include ${gtest_SOURCE_DIR} )
426
537
target_link_libraries (test_sampling PRIVATE gtest gtest_main)
427
538
target_compile_options (test_sampling PRIVATE -Wno-switch-bool )
539
+
# ROCm builds: select HIP as the linker language of the sampling benchmark and
# test in a single call.
if(HIP_FOUND)
  set_target_properties(
    bench_sampling
    test_sampling
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
428
544
endif (FLASHINFER_SAMPLING)
429
545
# Normalization kernel benchmark (bench_norm) and unit test (test_norm).
# Guarded by FLASHINFER_NORM again instead of an unconditional if(TRUE): the
# nvbench/hipbench and googletest subdirectories these targets link against are
# only added when one of the FLASHINFER_* test options is enabled, so building
# them unconditionally breaks configurations with every option switched off.
if(FLASHINFER_NORM)
  message(STATUS "Compile normalization kernel benchmarks.")
  file(GLOB_RECURSE BENCH_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/bench_norm.cu)

  if(HIP_FOUND)
    # The sources are flagged with HIP_SOURCE_PROPERTY_FORMAT and built via
    # hip_add_executable on ROCm.
    set_source_files_properties(${BENCH_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
    hip_add_executable(bench_norm ${BENCH_NORM_SRCS})
    target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/hipbench)
  else()
    add_executable(bench_norm ${BENCH_NORM_SRCS})
    target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench)
  endif()

  target_include_directories(bench_norm PRIVATE ${FLASHINFER_INCLUDE_DIR})
  # NOTE(review): assumes 3rdparty/hipbench also provides the nvbench::main
  # target on ROCm builds -- TODO confirm.
  target_link_libraries(bench_norm PRIVATE nvbench::main)
  target_compile_options(bench_norm PRIVATE -Wno-switch-bool)

  message(STATUS "Compile normalization kernel tests.")
  file(GLOB_RECURSE TEST_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/test_norm.cu)

  if(HIP_FOUND)
    set_source_files_properties(${TEST_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
    hip_add_executable(test_norm ${TEST_NORM_SRCS})
  else()
    add_executable(test_norm ${TEST_NORM_SRCS})
  endif()

  target_include_directories(test_norm PRIVATE ${FLASHINFER_INCLUDE_DIR})
  target_include_directories(test_norm PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
  target_link_libraries(test_norm PRIVATE gtest gtest_main)
  target_compile_options(test_norm PRIVATE -Wno-switch-bool)

  # ROCm builds: select HIP as the linker language of both executables.
  if(HIP_FOUND)
    set_target_properties(bench_norm test_norm PROPERTIES LINKER_LANGUAGE HIP)
  endif()
endif()
447
583
448
- if (FLASHINFER_TVM_BINDING)
584
+ if (FLASHINFER_TVM_BINDING)
449
585
message (STATUS "Compile tvm binding." )
450
586
if (NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "" )
451
587
set (TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR} )
@@ -477,6 +613,10 @@ if(FLASHINFER_FASTDIV_TEST)
477
613
target_include_directories (test_fastdiv PRIVATE ${FLASHINFER_INCLUDE_DIR} )
478
614
target_include_directories (test_fastdiv PRIVATE ${gtest_SOURCE_DIR} /include ${gtest_SOURCE_DIR} )
479
615
target_link_libraries (test_fastdiv PRIVATE gtest gtest_main)
616
+
# ROCm builds: select HIP as the linker language of the fastdiv test.
if(HIP_FOUND)
  set_property(TARGET test_fastdiv PROPERTY LINKER_LANGUAGE HIP)
endif()
480
620
endif (FLASHINFER_FASTDIV_TEST)
481
621
482
622
if (FLASHINFER_FASTDEQUANT_TEST)
@@ -486,9 +626,11 @@ if(FLASHINFER_FASTDEQUANT_TEST)
486
626
target_include_directories (test_fast_dequant PRIVATE ${FLASHINFER_INCLUDE_DIR} )
487
627
target_include_directories (test_fast_dequant PRIVATE ${gtest_SOURCE_DIR} /include ${gtest_SOURCE_DIR} )
488
628
target_link_libraries (test_fast_dequant PRIVATE gtest gtest_main)
489
- endif (FLASHINFER_FASTDEQUANT_TEST)
490
-
491
629
# ROCm builds: select HIP as the linker language of the fast-dequant test.
if(HIP_FOUND)
  set_property(TARGET test_fast_dequant PROPERTY LINKER_LANGUAGE HIP)
endif()
633
+ endif (FLASHINFER_FASTDEQUANT_TEST)
492
634
493
635
if (FLASHINFER_DISTRIBUTED)
494
636
find_package (MPI REQUIRED)
@@ -506,4 +648,9 @@ if (FLASHINFER_DISTRIBUTED)
506
648
target_include_directories (test_attn_all_reduce PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include 3rdparty/spdlog/include )
507
649
target_link_libraries (test_attn_all_reduce PRIVATE MPI::MPI_CXX mscclpp)
508
650
target_compile_definitions (test_attn_all_reduce PRIVATE -DENABLE_MPI)
651
+
# ROCm builds: select HIP as the linker language of both all-reduce tests in a
# single call.
if(HIP_FOUND)
  set_target_properties(
    test_sum_all_reduce
    test_attn_all_reduce
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
509
656
endif (FLASHINFER_DISTRIBUTED)
0 commit comments