@@ -260,16 +260,27 @@ __configure_fbgemm_gpu_build_cuda () {
260260 # https://github.com/NVIDIA/nvbench/discussions/129
261261 # https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L187
262262 # https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp#L224
263+
264+ # NOTE: It turns out that the order of the arch_list matters, and that
265+ # appending 7.0/7.5 to the back of the list mysteriously results in
266+ # undefined symbol errors on .SO loads
267+ if [[ $fbgemm_build_target == " hstu" ]]; then
268+ # HSTU requires sm_75 or higher
269+ local arch_list=" 7.5"
270+ else
271+ local arch_list=" 7.0"
272+ fi
273+
263274 if [[ $cuda_version_nvcc == * " V12.8" * ]]; then
264- local arch_list=" 7.0 ;8.0;9.0a;10.0a;12.0a"
275+ local arch_list=" ${arch_list} ;8.0;9.0a;10.0a;12.0a"
265276
266277 elif [[ $cuda_version_nvcc == * " V12.6" * ]] ||
267278 [[ $cuda_version_nvcc == * " V12.4" * ]] ||
268279 [[ $cuda_version_nvcc == * " V12.1" * ]]; then
269- local arch_list=" 7.0 ;8.0;9.0a"
280+ local arch_list=" ${arch_list} ;8.0;9.0a"
270281
271282 else
272- local arch_list=" 7.0 ;8.0;9.0"
283+ local arch_list=" ${arch_list} ;8.0;9.0"
273284 fi
274285 fi
275286 echo " [BUILD] Setting the following CUDA targets: ${arch_list} "
@@ -474,31 +485,29 @@ __build_fbgemm_gpu_common_pre_steps () {
474485 # Private function that uses variables instantiated by its caller
475486
476487 # Check C/C++ compilers are visible (the build scripts look specifically for `gcc`)
477- (test_binpath " ${env_name} " cc) || return 1
478- (test_binpath " ${env_name} " gcc) || return 1
479- (test_binpath " ${env_name} " c++) || return 1
480- (test_binpath " ${env_name} " g++) || return 1
488+ (test_binpath " ${env_name} " cc) || return 1
489+ (test_binpath " ${env_name} " gcc) || return 1
490+ (test_binpath " ${env_name} " c++) || return 1
491+ (test_binpath " ${env_name} " g++) || return 1
481492
482493 # Set the default the FBGEMM build variant to be default (i.e. FBGEMM_GPU)
483- if [ " $fbgemm_build_target " != " genai " ] &&
484- [ " $fbgemm_build_target " != " default " ]; then
494+ # shellcheck disable=SC2076
495+ if [[ ! " genai hstu default " =~ " $fbgemm_build_target " ] ]; then
485496 echo " ################################################################################"
486497 echo " [BUILD] Unknown FBGEMM build TARGET: ${fbgemm_build_target} "
487- echo " [BUILD] Defaulting to 'default' "
498+ echo " [BUILD] Exiting ... "
488499 echo " ################################################################################"
489- export fbgemm_build_target= " default "
500+ return 1
490501 fi
491502
492503 # Set the default the FBGEMM build variant to be CUDA
493- if [ " $fbgemm_build_variant " != " docs" ] &&
494- [ " $fbgemm_build_variant " != " cpu" ] &&
495- [ " $fbgemm_build_variant " != " cuda" ] &&
496- [ " $fbgemm_build_variant " != " rocm" ]; then
504+ # shellcheck disable=SC2076
505+ if [[ ! " docs cpu cuda rocm " =~ " $fbgemm_build_variant " ]]; then
497506 echo " ################################################################################"
498507 echo " [BUILD] Unknown FBGEMM build VARIANT: ${fbgemm_build_variant} "
499- echo " [BUILD] Defaulting to CUDA "
508+ echo " [BUILD] Exiting ... "
500509 echo " ################################################################################"
501- export fbgemm_build_variant= " cuda "
510+ return 1
502511 fi
503512
504513 # Extract and set the Python tag
@@ -603,6 +612,11 @@ __verify_library_symbols () {
603612 )
604613 fi
605614
615+ elif [ " ${fbgemm_build_target} " == " hstu" ]; then
616+ local lib_symbols_to_check=(
617+ fbgemm_gpu::hstu::set_params_fprop
618+ )
619+
606620 else
607621 local lib_symbols_to_check=(
608622 fbgemm_gpu::asynchronous_inclusive_cumsum_cpu
0 commit comments