|
| 1 | +# https://pytorch.org/get-started/locally/ |
| 2 | + |
| 3 | +variable "TORCH_META" { |
| 4 | + default = { |
| 5 | + "2.8.0" = { |
| 6 | + torchvision = "0.23.0" |
| 7 | + } |
| 8 | + "2.7.1" = { |
| 9 | + torchvision = "0.22.1" |
| 10 | + } |
| 11 | + "2.6.0" = { |
| 12 | + torchvision = "0.21.0" |
| 13 | + } |
| 14 | + } |
| 15 | +} |
| 16 | + |
| 17 | +# We need to grab the most compatible wheel for a given CUDA version and Torch version pair |
| 18 | +# At times, this requires grabbing a wheel built for a different CUDA version. |
| 19 | +variable "CUDA_TORCH_COMBINATIONS" { |
| 20 | + default = [ |
| 21 | + { cuda_version = "12.8.1", torch = "2.6.0", whl_src = "126" }, |
| 22 | + { cuda_version = "12.8.1", torch = "2.7.1", whl_src = "128" }, |
| 23 | + { cuda_version = "12.8.1", torch = "2.8.0", whl_src = "128" }, |
| 24 | + |
| 25 | + { cuda_version = "12.9.0", torch = "2.6.0", whl_src = "126" }, |
| 26 | + { cuda_version = "12.9.0", torch = "2.7.1", whl_src = "128" }, |
| 27 | + { cuda_version = "12.9.0", torch = "2.8.0", whl_src = "129" }, |
| 28 | + |
| 29 | + { cuda_version = "13.0.0", torch = "2.6.0", whl_src = "126" }, |
| 30 | + { cuda_version = "13.0.0", torch = "2.7.1", whl_src = "128" }, |
| 31 | + { cuda_version = "13.0.0", torch = "2.8.0", whl_src = "129" } |
| 32 | + ] |
| 33 | +} |
| 34 | + |
| 35 | +variable "COMPATIBLE_BUILDS" { |
| 36 | + default = flatten([ |
| 37 | + for combo in CUDA_TORCH_COMBINATIONS : [ |
| 38 | + for cuda in CUDA_VERSIONS : [ |
| 39 | + for ubuntu in UBUNTU_VERSIONS : { |
| 40 | + ubuntu_version = ubuntu.version |
| 41 | + ubuntu_name = ubuntu.name |
| 42 | + cuda_version = cuda.version |
| 43 | + cuda_code = replace(cuda.version, ".", "") |
| 44 | + wheel_src = combo.whl_src |
| 45 | + torch = combo.torch |
| 46 | + torch_code = replace(combo.torch, ".", "") |
| 47 | + torch_vision = TORCH_META[combo.torch].torchvision |
| 48 | + } if cuda.version == combo.cuda_version && contains(cuda.ubuntu, ubuntu.version) |
| 49 | + ] |
| 50 | + ] |
| 51 | + ]) |
| 52 | +} |
| 53 | + |
| 54 | +group "dev" { |
| 55 | + targets = ["pytorch-ubuntu2404-cu1281-torch280"] |
| 56 | +} |
| 57 | + |
| 58 | +group "default" { |
| 59 | + targets = [ |
| 60 | + for build in COMPATIBLE_BUILDS: |
| 61 | + "pytorch-${build.ubuntu_name}-cu${replace(build.cuda_version, ".", "")}-torch${build.torch_code}" |
| 62 | + ] |
| 63 | +} |
| 64 | + |
| 65 | +target "pytorch-base" { |
| 66 | + context = "official-templates/pytorch" |
| 67 | + dockerfile = "Dockerfile" |
| 68 | + platforms = ["linux/amd64"] |
| 69 | +} |
| 70 | + |
| 71 | +target "pytorch-matrix" { |
| 72 | + matrix = { |
| 73 | + build = COMPATIBLE_BUILDS |
| 74 | + } |
| 75 | + |
| 76 | + name = "pytorch-${build.ubuntu_name}-cu${build.cuda_code}-torch${build.torch_code}" |
| 77 | + |
| 78 | + inherits = ["pytorch-base"] |
| 79 | + |
| 80 | + args = { |
| 81 | + BASE_IMAGE = "runpod/base:${RELEASE_VERSION}${RELEASE_SUFFIX}-cuda${build.cuda_code}-${build.ubuntu_name}" |
| 82 | + WHEEL_SRC = build.wheel_src |
| 83 | + TORCH = "torch==${build.torch} torchvision==${build.torch_vision} torchaudio==${build.torch}" |
| 84 | + } |
| 85 | + |
| 86 | + tags = [ |
| 87 | + "runpod/pytorch:${RELEASE_VERSION}${RELEASE_SUFFIX}-cu${build.cuda_code}-torch${build.torch_code}-${build.ubuntu_name}", |
| 88 | + ] |
| 89 | +} |
0 commit comments