Skip to content

Commit 89e5fda

Browse files
authored
feat(scheduling): support assigning multiple GPUs per worker instance (bentoml#3950)
* feat(scheduling): support assigning multiple GPUs per worker instance Signed-off-by: aarnphm-ec2-dev <[email protected]> * tests: add test cases for fractional GPU Signed-off-by: aarnphm-ec2-dev <[email protected]> * fix(strategy): arithmetic counting Signed-off-by: aarnphm-ec2-dev <[email protected]> * chore(logs): add an error log about the exception Signed-off-by: aarnphm-ec2-dev <[email protected]> * fix(strategy): rounding the assigned resources Signed-off-by: Aaron <[email protected]> * chore: update grammar of the exception message --------- Signed-off-by: aarnphm-ec2-dev <[email protected]> Signed-off-by: Aaron <[email protected]>
1 parent 5122f91 commit 89e5fda

File tree

4 files changed

+109
-32
lines changed

4 files changed

+109
-32
lines changed

src/bentoml/_internal/configuration/v1/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@
150150
# NOTE: there is a distinction between being unset and None here; if set to 'None'
151151
# in configuration for a specific runner, it will override the global configuration.
152152
s.Optional("resources"): s.Or({s.Optional(str): object}, lambda s: s == "system", None), # type: ignore (incomplete schema typing)
153-
s.Optional("workers_per_resource"): s.And(int, ensure_larger_than_zero),
153+
s.Optional("workers_per_resource"): s.And(
154+
s.Or(int, float), ensure_larger_than_zero
155+
),
154156
s.Optional("logging"): {
155157
s.Optional("access"): {
156158
s.Optional("enabled"): bool,

src/bentoml/_internal/runner/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __getattr__(self, item: str) -> t.Any:
130130

131131
runner_methods: list[RunnerMethod[t.Any, t.Any, t.Any]]
132132
scheduling_strategy: type[Strategy]
133-
workers_per_resource: int = 1
133+
workers_per_resource: int | float = 1
134134
runnable_init_params: dict[str, t.Any] = attr.field(
135135
default=None, converter=attr.converters.default_if_none(factory=dict)
136136
)

src/bentoml/_internal/runner/strategy.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from __future__ import annotations
22

33
import abc
4+
import logging
45
import math
56
import typing as t
6-
import logging
77

8+
from ..resource import get_resource, system_resources
89
from .runnable import Runnable
9-
from ..resource import get_resource
10-
from ..resource import system_resources
1110

1211
logger = logging.getLogger(__name__)
1312

@@ -18,8 +17,8 @@ class Strategy(abc.ABC):
1817
def get_worker_count(
1918
cls,
2019
runnable_class: t.Type[Runnable],
21-
resource_request: dict[str, t.Any],
22-
workers_per_resource: int,
20+
resource_request: dict[str, t.Any] | None,
21+
workers_per_resource: int | float,
2322
) -> int:
2423
...
2524

@@ -28,19 +27,15 @@ def get_worker_count(
2827
def get_worker_env(
2928
cls,
3029
runnable_class: t.Type[Runnable],
31-
resource_request: dict[str, t.Any],
32-
workers_per_resource: int,
30+
resource_request: dict[str, t.Any] | None,
31+
workers_per_resource: int | float,
3332
worker_index: int,
3433
) -> dict[str, t.Any]:
3534
"""
36-
Parameters
37-
----------
38-
runnable_class : type[Runnable]
39-
The runnable class to be run.
40-
resource_request : dict[str, Any]
41-
The resource request of the runnable.
42-
worker_index : int
43-
The index of the worker, start from 0.
35+
Args:
36+
runnable_class : The runnable class to be run.
37+
resource_request : The resource request of the runnable.
38+
worker_index : The index of the worker, start from 0.
4439
"""
4540
...
4641

@@ -66,7 +61,7 @@ def get_worker_count(
6661
cls,
6762
runnable_class: t.Type[Runnable],
6863
resource_request: dict[str, t.Any] | None,
69-
workers_per_resource: int,
64+
workers_per_resource: int | float,
7065
) -> int:
7166
if resource_request is None:
7267
resource_request = system_resources()
@@ -78,7 +73,7 @@ def get_worker_count(
7873
and len(nvidia_gpus) > 0
7974
and "nvidia.com/gpu" in runnable_class.SUPPORTED_RESOURCES
8075
):
81-
return len(nvidia_gpus) * workers_per_resource
76+
return math.ceil(len(nvidia_gpus) * workers_per_resource)
8277

8378
# use CPU
8479
cpus = get_resource(resource_request, "cpu")
@@ -90,6 +85,10 @@ def get_worker_count(
9085
)
9186

9287
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
88+
if isinstance(workers_per_resource, float):
89+
raise ValueError(
90+
"Fractional CPU multi threading support is not yet supported."
91+
)
9392
return workers_per_resource
9493

9594
return math.ceil(cpus) * workers_per_resource
@@ -105,31 +104,53 @@ def get_worker_env(
105104
cls,
106105
runnable_class: t.Type[Runnable],
107106
resource_request: dict[str, t.Any] | None,
108-
workers_per_resource: int,
107+
workers_per_resource: int | float,
109108
worker_index: int,
110109
) -> dict[str, t.Any]:
111110
"""
112-
Parameters
113-
----------
114-
runnable_class : type[Runnable]
115-
The runnable class to be run.
116-
resource_request : dict[str, Any]
117-
The resource request of the runnable.
118-
worker_index : int
119-
The index of the worker, start from 0.
111+
Args:
112+
runnable_class : The runnable class to be run.
113+
resource_request : The resource request of the runnable.
114+
worker_index : The index of the worker, start from 0.
120115
"""
121116
environ: dict[str, t.Any] = {}
122117
if resource_request is None:
123118
resource_request = system_resources()
124-
125119
# use nvidia gpu
126-
nvidia_gpus = get_resource(resource_request, "nvidia.com/gpu")
120+
nvidia_gpus: list[int] | None = get_resource(resource_request, "nvidia.com/gpu")
127121
if (
128122
nvidia_gpus is not None
129123
and len(nvidia_gpus) > 0
130124
and "nvidia.com/gpu" in runnable_class.SUPPORTED_RESOURCES
131125
):
132-
dev = str(nvidia_gpus[worker_index // workers_per_resource])
126+
if isinstance(workers_per_resource, float):
127+
# NOTE: We hit this branch when workers_per_resource is set to
128+
# float, for example 0.5 or 0.25
129+
if workers_per_resource > 1:
130+
raise ValueError(
131+
"Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case."
132+
)
133+
# We round the assigned resource here. This means if workers_per_resource=.4
134+
# then it will round down to 2. If workers_per_resource=0.6, then it will also round up to 2.
135+
assigned_resource_per_worker = round(1 / workers_per_resource)
136+
if len(nvidia_gpus) < assigned_resource_per_worker:
137+
logger.warning(
138+
"Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])",
139+
nvidia_gpus,
140+
worker_index,
141+
assigned_resource_per_worker,
142+
)
143+
raise IndexError(
144+
f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}]."
145+
)
146+
assigned_gpu = nvidia_gpus[
147+
assigned_resource_per_worker
148+
* worker_index : assigned_resource_per_worker
149+
* (worker_index + 1)
150+
]
151+
dev = ",".join(map(str, assigned_gpu))
152+
else:
153+
dev = str(nvidia_gpus[worker_index // workers_per_resource])
133154
environ["CUDA_VISIBLE_DEVICES"] = dev
134155
logger.info(
135156
"Environ for worker %s: set CUDA_VISIBLE_DEVICES to %s",

tests/unit/_internal/runner/test_strategy.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from _pytest.monkeypatch import MonkeyPatch
99

1010
import bentoml
11-
from bentoml._internal.runner import strategy
1211
from bentoml._internal.resource import get_resource
12+
from bentoml._internal.runner import strategy
1313
from bentoml._internal.runner.strategy import DefaultStrategy
1414

1515

@@ -51,6 +51,29 @@ def test_default_gpu_strategy(monkeypatch: MonkeyPatch):
5151
== 4
5252
)
5353

54+
assert (
55+
DefaultStrategy.get_worker_count(GPURunnable, {"nvidia.com/gpu": [2, 7]}, 0.5)
56+
== 1
57+
)
58+
assert (
59+
DefaultStrategy.get_worker_count(
60+
GPURunnable, {"nvidia.com/gpu": [2, 7, 9]}, 0.5
61+
)
62+
== 2
63+
)
64+
assert (
65+
DefaultStrategy.get_worker_count(
66+
GPURunnable, {"nvidia.com/gpu": [2, 7, 8, 9]}, 0.5
67+
)
68+
== 2
69+
)
70+
assert (
71+
DefaultStrategy.get_worker_count(
72+
GPURunnable, {"nvidia.com/gpu": [2, 5, 7, 8, 9]}, 0.4
73+
)
74+
== 2
75+
)
76+
5477
envs = DefaultStrategy.get_worker_env(GPURunnable, {"nvidia.com/gpu": 2}, 1, 0)
5578
assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
5679
envs = DefaultStrategy.get_worker_env(GPURunnable, {"nvidia.com/gpu": 2}, 1, 1)
@@ -69,6 +92,37 @@ def test_default_gpu_strategy(monkeypatch: MonkeyPatch):
6992
envs = DefaultStrategy.get_worker_env(GPURunnable, {"nvidia.com/gpu": [2, 7]}, 2, 2)
7093
assert envs.get("CUDA_VISIBLE_DEVICES") == "7"
7194

95+
envs = DefaultStrategy.get_worker_env(
96+
GPURunnable, {"nvidia.com/gpu": [2, 7]}, 0.5, 0
97+
)
98+
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
99+
100+
envs = DefaultStrategy.get_worker_env(
101+
GPURunnable, {"nvidia.com/gpu": [2, 7, 8, 9]}, 0.5, 0
102+
)
103+
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
104+
envs = DefaultStrategy.get_worker_env(
105+
GPURunnable, {"nvidia.com/gpu": [2, 7, 8, 9]}, 0.5, 1
106+
)
107+
assert envs.get("CUDA_VISIBLE_DEVICES") == "8,9"
108+
envs = DefaultStrategy.get_worker_env(
109+
GPURunnable, {"nvidia.com/gpu": [2, 7, 8, 9]}, 0.25, 0
110+
)
111+
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7,8,9"
112+
113+
envs = DefaultStrategy.get_worker_env(
114+
GPURunnable, {"nvidia.com/gpu": [2, 6, 7, 8, 9]}, 0.4, 0
115+
)
116+
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,6"
117+
envs = DefaultStrategy.get_worker_env(
118+
GPURunnable, {"nvidia.com/gpu": [2, 6, 7, 8, 9]}, 0.4, 1
119+
)
120+
assert envs.get("CUDA_VISIBLE_DEVICES") == "7,8"
121+
envs = DefaultStrategy.get_worker_env(
122+
GPURunnable, {"nvidia.com/gpu": [2, 6, 7, 8, 9]}, 0.4, 2
123+
)
124+
assert envs.get("CUDA_VISIBLE_DEVICES") == "9"
125+
72126

73127
def test_default_cpu_strategy(monkeypatch: MonkeyPatch):
74128
monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)

0 commit comments

Comments
 (0)