Skip to content

Commit 1fd5002

Browse files
authored
Merge pull request #89 from quinn-dougherty/main
bump: `torch` to `>=2.1.1` (rm workaround) (#87)
2 parents 1d9dd6f + d276e4a commit 1fd5002

File tree

2 files changed

+57
-62
lines changed

2 files changed

+57
-62
lines changed

python/poetry.lock

Lines changed: 55 additions & 38 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/pyproject.toml

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,8 @@ importlib-metadata = ">=5.1.0"
1212
numpy = [{ version = ">=1.20,<1.25", python = ">=3.8,<3.9" },
1313
{ version = ">=1.24", python = ">=3.9,<3.12" },
1414
{ version = ">=1.26", python = ">=3.12,<3.13" }]
15-
python = ">=3.8"
16-
torch = ">=1.10" # See PyTorch 2 fix below
17-
# PyTorch 2.1.0 Bug Fix PyTorch didn't put their dependencies metadata into all wheels for 2.1.0, so
18-
# it doesn't work with Poetry. This is a known bug - the workaround is to place them manually here
19-
# (from the one wheel that did correctly list them). This was broken in 2.0.1 and the fix wasn't
20-
# made for 2.1.0, however Meta are aware of the issue and once it is fixed (and the torch version
21-
# requirement bumped) this should be removed. Note also the python version is used to specify that
22-
# this is only added where v2 torch is installed (as per the torch version requirement above).
23-
# https://github.com/pytorch/pytorch/issues/100974
24-
# https://github.com/python-poetry/poetry/issues/7902#issuecomment-1583078794
25-
nvidia-cuda-nvrtc-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
26-
nvidia-cuda-runtime-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
27-
nvidia-cuda-cupti-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
28-
nvidia-cudnn-cu12 = { version = "==8.9.2.26", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
29-
nvidia-cublas-cu12 = { version = "==12.1.3.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
30-
nvidia-cufft-cu12 = { version = "==11.0.2.54", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
31-
nvidia-curand-cu12 = { version = "==10.3.2.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
32-
nvidia-cusolver-cu12 = { version = "==11.4.5.107", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
33-
nvidia-cusparse-cu12 = { version = "==12.1.0.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
34-
nvidia-nccl-cu12 = { version = "==2.18.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
35-
nvidia-nvtx-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
36-
triton = { version = "==2.1.0", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
37-
# End PyTorch 2.1.0 Bug Fix
15+
python = ">=3.8"
16+
torch = ">=2.1.1"
3817

3918
[tool.poetry.group.dev.dependencies]
4019
autopep8 = ">=2.0"
@@ -44,7 +23,6 @@ pytest = ">=7.2"
4423
snapshottest = ">=0.6"
4524
twine = ">=4.0.1"
4625

47-
4826
[tool.poetry.group.jupyter.dependencies]
4927
jupyterlab = ">=3.5"
5028

0 commit comments

Comments
 (0)