11from __future__ import annotations
22
33import abc
4+ import logging
45import math
56import typing as t
6- import logging
77
8+ from ..resource import get_resource , system_resources
89from .runnable import Runnable
9- from ..resource import get_resource
10- from ..resource import system_resources
1110
1211logger = logging .getLogger (__name__ )
1312
@@ -18,8 +17,8 @@ class Strategy(abc.ABC):
1817 def get_worker_count (
1918 cls ,
2019 runnable_class : t .Type [Runnable ],
21- resource_request : dict [str , t .Any ],
22- workers_per_resource : int ,
20+ resource_request : dict [str , t .Any ] | None ,
21+ workers_per_resource : int | float ,
2322 ) -> int :
2423 ...
2524
@@ -28,19 +27,15 @@ def get_worker_count(
2827 def get_worker_env (
2928 cls ,
3029 runnable_class : t .Type [Runnable ],
31- resource_request : dict [str , t .Any ],
32- workers_per_resource : int ,
30+ resource_request : dict [str , t .Any ] | None ,
31+ workers_per_resource : int | float ,
3332 worker_index : int ,
3433 ) -> dict [str , t .Any ]:
3534 """
36- Parameters
37- ----------
38- runnable_class : type[Runnable]
39- The runnable class to be run.
40- resource_request : dict[str, Any]
41- The resource request of the runnable.
42- worker_index : int
43- The index of the worker, start from 0.
35+ Args:
36+ runnable_class : The runnable class to be run.
37+ resource_request : The resource request of the runnable.
38+ worker_index : The index of the worker, start from 0.
4439 """
4540 ...
4641
@@ -66,7 +61,7 @@ def get_worker_count(
6661 cls ,
6762 runnable_class : t .Type [Runnable ],
6863 resource_request : dict [str , t .Any ] | None ,
69- workers_per_resource : int ,
64+ workers_per_resource : int | float ,
7065 ) -> int :
7166 if resource_request is None :
7267 resource_request = system_resources ()
@@ -78,7 +73,7 @@ def get_worker_count(
7873 and len (nvidia_gpus ) > 0
7974 and "nvidia.com/gpu" in runnable_class .SUPPORTED_RESOURCES
8075 ):
81- return len (nvidia_gpus ) * workers_per_resource
76+ return math . ceil ( len (nvidia_gpus ) * workers_per_resource )
8277
8378 # use CPU
8479 cpus = get_resource (resource_request , "cpu" )
@@ -90,6 +85,10 @@ def get_worker_count(
9085 )
9186
9287 if runnable_class .SUPPORTS_CPU_MULTI_THREADING :
88+ if isinstance (workers_per_resource , float ):
89+ raise ValueError (
90+ "Fractional CPU multi threading support is not yet supported."
91+ )
9392 return workers_per_resource
9493
9594 return math .ceil (cpus ) * workers_per_resource
@@ -105,31 +104,53 @@ def get_worker_env(
105104 cls ,
106105 runnable_class : t .Type [Runnable ],
107106 resource_request : dict [str , t .Any ] | None ,
108- workers_per_resource : int ,
107+ workers_per_resource : int | float ,
109108 worker_index : int ,
110109 ) -> dict [str , t .Any ]:
111110 """
112- Parameters
113- ----------
114- runnable_class : type[Runnable]
115- The runnable class to be run.
116- resource_request : dict[str, Any]
117- The resource request of the runnable.
118- worker_index : int
119- The index of the worker, start from 0.
111+ Args:
112+ runnable_class : The runnable class to be run.
113+ resource_request : The resource request of the runnable.
114+ worker_index : The index of the worker, start from 0.
120115 """
121116 environ : dict [str , t .Any ] = {}
122117 if resource_request is None :
123118 resource_request = system_resources ()
124-
125119 # use nvidia gpu
126- nvidia_gpus = get_resource (resource_request , "nvidia.com/gpu" )
120+ nvidia_gpus : list [ int ] | None = get_resource (resource_request , "nvidia.com/gpu" )
127121 if (
128122 nvidia_gpus is not None
129123 and len (nvidia_gpus ) > 0
130124 and "nvidia.com/gpu" in runnable_class .SUPPORTED_RESOURCES
131125 ):
132- dev = str (nvidia_gpus [worker_index // workers_per_resource ])
126+ if isinstance (workers_per_resource , float ):
127+ # NOTE: We hit this branch when workers_per_resource is set to
128+ # float, for example 0.5 or 0.25
129+ if workers_per_resource > 1 :
130+ raise ValueError (
131+ "Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case."
132+ )
133+ # We are round the assigned resource here. This means if workers_per_resource=.4
134+ # then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
135+ assigned_resource_per_worker = round (1 / workers_per_resource )
136+ if len (nvidia_gpus ) < assigned_resource_per_worker :
137+ logger .warning (
138+ "Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])" ,
139+ nvidia_gpus ,
140+ worker_index ,
141+ assigned_resource_per_worker ,
142+ )
143+ raise IndexError (
144+ f"There aren't enough assigned GPU(s) for given worker id '{ worker_index } ' [required: { assigned_resource_per_worker } ]."
145+ )
146+ assigned_gpu = nvidia_gpus [
147+ assigned_resource_per_worker
148+ * worker_index : assigned_resource_per_worker
149+ * (worker_index + 1 )
150+ ]
151+ dev = "," .join (map (str , assigned_gpu ))
152+ else :
153+ dev = str (nvidia_gpus [worker_index // workers_per_resource ])
133154 environ ["CUDA_VISIBLE_DEVICES" ] = dev
134155 logger .info (
135156 "Environ for worker %s: set CUDA_VISIBLE_DEVICES to %s" ,
0 commit comments