Skip to content

Commit 2f2a1ef

Browse files
q10, facebook-github-bot
authored and committed
Integrate HSTU into OSS CI (#4236)
Summary:
X-link: facebookresearch/FBGEMM#1396
- Integrate HSTU build into OSS CI
- Earlier draft of the work: #4251

Pull Request resolved: #4236
Reviewed By: ionuthristodorescu
Differential Revision: D76445631
Pulled By: q10
fbshipit-source-id: 4b1eafc557b2db4480080182c2759689e0bead2f
1 parent 2b8c71d commit 2f2a1ef

File tree

10 files changed

+69
-29
lines changed

10 files changed

+69
-29
lines changed

.github/scripts/fbgemm_gpu_build.bash

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -260,16 +260,27 @@ __configure_fbgemm_gpu_build_cuda () {
260260
# https://github.com/NVIDIA/nvbench/discussions/129
261261
# https://github.com/vllm-project/vllm/blob/main/CMakeLists.txt#L187
262262
# https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp#L224
263+
264+
# NOTE: It turns out that the order of the arch_list matters, and that
265+
# appending 7.0/7.5 to the back of the list mysteriously results in
266+
# undefined symbol errors on .SO loads
267+
if [[ $fbgemm_build_target == "hstu" ]]; then
268+
# HSTU requires sm_75 or higher
269+
local arch_list="7.5"
270+
else
271+
local arch_list="7.0"
272+
fi
273+
263274
if [[ $cuda_version_nvcc == *"V12.8"* ]]; then
264-
local arch_list="7.0;8.0;9.0a;10.0a;12.0a"
275+
local arch_list="${arch_list};8.0;9.0a;10.0a;12.0a"
265276

266277
elif [[ $cuda_version_nvcc == *"V12.6"* ]] ||
267278
[[ $cuda_version_nvcc == *"V12.4"* ]] ||
268279
[[ $cuda_version_nvcc == *"V12.1"* ]]; then
269-
local arch_list="7.0;8.0;9.0a"
280+
local arch_list="${arch_list};8.0;9.0a"
270281

271282
else
272-
local arch_list="7.0;8.0;9.0"
283+
local arch_list="${arch_list};8.0;9.0"
273284
fi
274285
fi
275286
echo "[BUILD] Setting the following CUDA targets: ${arch_list}"
@@ -474,31 +485,29 @@ __build_fbgemm_gpu_common_pre_steps () {
474485
# Private function that uses variables instantiated by its caller
475486

476487
# Check C/C++ compilers are visible (the build scripts look specifically for `gcc`)
477-
(test_binpath "${env_name}" cc) || return 1
478-
(test_binpath "${env_name}" gcc) || return 1
479-
(test_binpath "${env_name}" c++) || return 1
480-
(test_binpath "${env_name}" g++) || return 1
488+
(test_binpath "${env_name}" cc) || return 1
489+
(test_binpath "${env_name}" gcc) || return 1
490+
(test_binpath "${env_name}" c++) || return 1
491+
(test_binpath "${env_name}" g++) || return 1
481492

482493
# Check that the FBGEMM build target is one of the known targets (genai, hstu, default)
483-
if [ "$fbgemm_build_target" != "genai" ] &&
484-
[ "$fbgemm_build_target" != "default" ]; then
494+
# shellcheck disable=SC2076
495+
if [[ ! " genai hstu default " =~ " $fbgemm_build_target " ]]; then
485496
echo "################################################################################"
486497
echo "[BUILD] Unknown FBGEMM build TARGET: ${fbgemm_build_target}"
487-
echo "[BUILD] Defaulting to 'default'"
498+
echo "[BUILD] Exiting ..."
488499
echo "################################################################################"
489-
export fbgemm_build_target="default"
500+
return 1
490501
fi
491502

492503
# Check that the FBGEMM build variant is one of the known variants (docs, cpu, cuda, rocm)
493-
if [ "$fbgemm_build_variant" != "docs" ] &&
494-
[ "$fbgemm_build_variant" != "cpu" ] &&
495-
[ "$fbgemm_build_variant" != "cuda" ] &&
496-
[ "$fbgemm_build_variant" != "rocm" ]; then
504+
# shellcheck disable=SC2076
505+
if [[ ! " docs cpu cuda rocm " =~ " $fbgemm_build_variant " ]]; then
497506
echo "################################################################################"
498507
echo "[BUILD] Unknown FBGEMM build VARIANT: ${fbgemm_build_variant}"
499-
echo "[BUILD] Defaulting to CUDA"
508+
echo "[BUILD] Exiting ..."
500509
echo "################################################################################"
501-
export fbgemm_build_variant="cuda"
510+
return 1
502511
fi
503512

504513
# Extract and set the Python tag
@@ -603,6 +612,11 @@ __verify_library_symbols () {
603612
)
604613
fi
605614

615+
elif [ "${fbgemm_build_target}" == "hstu" ]; then
616+
local lib_symbols_to_check=(
617+
fbgemm_gpu::hstu::set_params_fprop
618+
)
619+
606620
else
607621
local lib_symbols_to_check=(
608622
fbgemm_gpu::asynchronous_inclusive_cumsum_cpu

.github/scripts/fbgemm_gpu_install.bash

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ __install_check_subpackages () {
7777
"fbgemm_gpu.tbe.cache"
7878
)
7979

80-
if [ "$installed_fbgemm_target" != "genai" ]; then
80+
if [ "$installed_fbgemm_target" == "default" ]; then
8181
subpackages+=(
8282
"fbgemm_gpu.split_embedding_codegen_lookup_invokers"
8383
"fbgemm_gpu.tbe.ssd"
@@ -91,10 +91,13 @@ __install_check_subpackages () {
9191
}
9292

9393
__install_check_operator_registrations () {
94+
# shellcheck disable=SC2155
9495
local env_prefix=$(env_name_or_prefix "${env_name}")
9596

9697
local test_operators=()
98+
local base_import="fbgemm_gpu"
9799
echo "[INSTALL] Check for operator registrations ..."
100+
98101
if [ "$installed_fbgemm_target" == "genai" ]; then
99102
# NOTE: Currently, ROCm builds of GenAI only include quantization
100103
# operators.
@@ -115,7 +118,13 @@ __install_check_operator_registrations () {
115118
fi
116119
fi
117120

118-
else
121+
elif [ "$installed_fbgemm_target" == "hstu" ]; then
122+
test_operators+=(
123+
"torch.ops.fbgemm.hstu_varlen_bwd_80"
124+
)
125+
base_import="fbgemm_gpu.experimental.hstu"
126+
127+
elif [ "$installed_fbgemm_target" == "genai" ]; then
119128
test_operators+=(
120129
"torch.ops.fbgemm.asynchronous_inclusive_cumsum"
121130
"torch.ops.fbgemm.split_embedding_codegen_lookup_sgd_function_pt2"
@@ -124,7 +133,7 @@ __install_check_operator_registrations () {
124133

125134
for operator in "${test_operators[@]}"; do
126135
# shellcheck disable=SC2086
127-
if conda run ${env_prefix} python -c "import torch; import fbgemm_gpu; print($operator)"; then
136+
if conda run ${env_prefix} python -c "import torch; import ${base_import}; print($operator)"; then
128137
echo "[CHECK] FBGEMM_GPU operator appears to be correctly registered: $operator"
129138
else
130139
echo "################################################################################"

.github/scripts/fbgemm_gpu_test.bash

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,8 @@ __setup_fbgemm_gpu_test () {
184184
print_exec conda env config vars set ${env_prefix} KMP_DUPLICATE_LIB_OK=1
185185
fi
186186

187-
# NOTE: Uncomment to enable PyTorch C++ stacktraces
188187
# shellcheck disable=SC2086
189-
# print_exec conda env config vars set ${env_prefix} TORCH_SHOW_CPP_STACKTRACES=1
188+
print_exec conda env config vars set ${env_prefix} TORCH_SHOW_CPP_STACKTRACES=1
190189

191190
echo "[TEST] Installing PyTest ..."
192191
# shellcheck disable=SC2086
@@ -267,6 +266,11 @@ __determine_test_directories () {
267266
)
268267
fi
269268

269+
elif [ "$fbgemm_build_target" == "hstu" ]; then
270+
target_directories+=(
271+
fbgemm_gpu/experimental/hstu/test
272+
)
273+
270274
else
271275
target_directories+=(
272276
fbgemm_gpu/test

.github/workflows/fbgemm_gpu_ci_cuda.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# This workflow is used for FBGEMM GPU/GenAI CUDA CI as well as nightly builds
7-
# of FBGEMM GPU/GenAI CUDA against PyTorch-CUDA Nightly.
8-
name: FBGEMM GPU/GenAI CUDA CI
6+
# This workflow is used for FBGEMM GPU/GenAI/HSTU CUDA CI as well as nightly
7+
# builds of FBGEMM GPU/GenAI/HSTU CUDA against PyTorch-CUDA Nightly.
8+
name: FBGEMM GPU/GenAI/HSTU CUDA CI
99

1010
on:
1111
# PR Trigger (enabled for regression checks and debugging)
@@ -74,9 +74,13 @@ jobs:
7474
{ arch: x86, instance: "linux.24xlarge", build-target: "default", cuda-version: "11.8.0" },
7575
{ arch: x86, instance: "linux.24xlarge", build-target: "default", cuda-version: "12.6.3" },
7676
{ arch: x86, instance: "linux.24xlarge", build-target: "default", cuda-version: "12.8.0" },
77+
7778
# GenAI is unable to support 11.8.0 anymore as of https://github.com/pytorch/FBGEMM/pull/4138
7879
{ arch: x86, instance: "linux.8xlarge.memory", build-target: "genai", cuda-version: "12.6.3" },
7980
{ arch: x86, instance: "linux.8xlarge.memory", build-target: "genai", cuda-version: "12.8.0" },
81+
82+
{ arch: x86, instance: "linux.12xlarge.memory", build-target: "hstu", cuda-version: "12.6.3" },
83+
{ arch: x86, instance: "linux.12xlarge.memory", build-target: "hstu", cuda-version: "12.8.0" },
8084
]
8185
python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
8286
compiler: [ "gcc", "clang" ]
@@ -167,6 +171,8 @@ jobs:
167171
{ build-target: "default", cuda-version: "12.8.0" },
168172
{ build-target: "genai", cuda-version: "12.6.3" },
169173
{ build-target: "genai", cuda-version: "12.8.0" },
174+
{ build-target: "hstu", cuda-version: "12.6.3" },
175+
{ build-target: "hstu", cuda-version: "12.8.0" },
170176
]
171177
python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
172178
# Specify exactly ONE CUDA version for artifact publish

fbgemm_gpu/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,7 @@ if(FBGEMM_BUILD_TARGET STREQUAL BUILD_TARGET_GENAI)
283283
add_subdirectory(experimental/gemm)
284284

285285
elseif(FBGEMM_BUILD_TARGET STREQUAL BUILD_TARGET_HSTU)
286-
if(FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU OR
287-
FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
286+
if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CUDA)
288287
message(FATAL_ERROR
289288
"Unsupported (target, variant) combination:
290289
(${FBGEMM_BUILD_TARGET}, ${FBGEMM_BUILD_VARIANT})")

fbgemm_gpu/experimental/hstu/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# FBGEMM HSTU
22

3-
FBGEMM HSTU(Hierarchical Sequential Transduction Units)
3+
FBGEMM HSTU (Hierarchical Sequential Transduction Units)
44

55
# **1. Overview**
66

fbgemm_gpu/experimental/hstu/test/hstu_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import logging
1111
import math
12+
import os
1213
import unittest
1314
from typing import Optional, Tuple
1415

@@ -19,6 +20,8 @@
1920

2021
from hypothesis import given, settings, strategies as st, Verbosity
2122

23+
running_on_github: bool = os.getenv("GITHUB_ENV") is not None
24+
2225
logger: logging.Logger = logging.getLogger()
2326
logger.setLevel(logging.INFO)
2427

@@ -453,6 +456,9 @@ def _hstu_attention_maybe_from_cache(
453456
class HSTU16Test(unittest.TestCase):
454457
"""Test HSTU attention with float16 inputs."""
455458

459+
@unittest.skipIf(
460+
running_on_github, "GitHub runners are unable to run the test at this time"
461+
)
456462
@given(
457463
batch_size=st.sampled_from([32]),
458464
heads=st.sampled_from([2]),

fbgemm_gpu/fbgemm_gpu/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _load_library(filename: str, no_throw: bool = False) -> None:
1717
torch.ops.load_library(os.path.join(os.path.dirname(__file__), filename))
1818
logging.info(f"Successfully loaded: '{filename}'")
1919
except Exception as error:
20-
logging.error(f"Could not load the library '{filename}': {error}")
20+
logging.error(f"Could not load the library '{filename}'!\n\n\n{error}\n\n\n")
2121
if not no_throw:
2222
raise error
2323

fbgemm_gpu/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ backports.tarfile
1414
build
1515
cmake
1616
click
17+
einops
1718
hypothesis
1819
jinja2
1920
mpmath==1.3.0

fbgemm_gpu/requirements_genai.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ backports.tarfile
1616
build
1717
cmake
1818
click
19+
einops
1920
hypothesis
2021
jinja2
2122
mpmath==1.3.0

0 commit comments

Comments
 (0)