Skip to content

Commit 644ccd7

Browse files
authored
ci: Improve compatibility with pytorch 2.5 (#711)
- Fix release workflow for pytorch 2.5 wheel by dropping python 3.8. - Relax pytorch version requirements. - Right now, install flashinfer will downgrade existing pytorch versions since it's set the version requirement to exact match.
1 parent ccd3be9 commit 644ccd7

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

.github/workflows/release_wheel.yml

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ jobs:
3333
torch: "2.2"
3434
- cuda: "12.4"
3535
torch: "2.3"
36+
- python: "3.8" # torch 2.5+ drops python 3.8
37+
torch: "2.5"
3638

3739
runs-on: [self-hosted]
3840
steps:

scripts/run-ci-build-wheel.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ else
4141
fi
4242

4343
echo "::group::Install PyTorch"
44-
pip install torch==$FLASHINFER_CI_TORCH_VERSION --index-url "https://download.pytorch.org/whl/cu${CUDA_MAJOR}${CUDA_MINOR}"
44+
pip install torch==${FLASHINFER_CI_TORCH_VERSION}.* --index-url "https://download.pytorch.org/whl/cu${CUDA_MAJOR}${CUDA_MINOR}"
4545
echo "::endgroup::"
4646

4747
echo "::group::Install build system"

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(self, *args, **kwargs) -> None:
135135
torch_full_version = Version(torch.__version__)
136136
torch_version = f"{torch_full_version.major}.{torch_full_version.minor}"
137137
cmdclass["build_ext"] = NinjaBuildExtension
138-
install_requires = [f"torch == {torch_version}"]
138+
install_requires = [f"torch == {torch_version}.*"]
139139

140140
aot_build_meta = {}
141141
aot_build_meta["cuda_major"] = cuda_version.major

0 commit comments

Comments
 (0)