22
22
import warnings
23
23
from dataclasses import dataclass , field , fields
24
24
from typing import Dict , List , Optional , Union , get_args , get_origin
25
- from kubernetes .client import V1Volume , V1VolumeMount
25
+ from kubernetes .client import V1Toleration , V1Volume , V1VolumeMount
26
26
27
27
dir = pathlib .Path (__file__ ).parent .parent .resolve ()
28
28
@@ -58,6 +58,8 @@ class ClusterConfiguration:
58
58
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
59
59
head_extended_resource_requests:
60
60
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61
+ head_tolerations:
62
+ List of tolerations for head nodes.
61
63
min_cpus:
62
64
The minimum number of CPUs to allocate to each worker.
63
65
max_cpus:
@@ -70,6 +72,8 @@ class ClusterConfiguration:
70
72
The maximum amount of memory to allocate to each worker.
71
73
num_gpus:
72
74
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
75
+ worker_tolerations:
76
+ List of tolerations for worker nodes.
73
77
appwrapper:
74
78
A boolean indicating whether to use an AppWrapper.
75
79
envs:
@@ -110,6 +114,7 @@ class ClusterConfiguration:
110
114
head_extended_resource_requests : Dict [str , Union [str , int ]] = field (
111
115
default_factory = dict
112
116
)
117
+ head_tolerations : Optional [List [V1Toleration ]] = None
113
118
worker_cpu_requests : Union [int , str ] = 1
114
119
worker_cpu_limits : Union [int , str ] = 1
115
120
min_cpus : Optional [Union [int , str ]] = None # Deprecating
@@ -120,6 +125,7 @@ class ClusterConfiguration:
120
125
min_memory : Optional [Union [int , str ]] = None # Deprecating
121
126
max_memory : Optional [Union [int , str ]] = None # Deprecating
122
127
num_gpus : Optional [int ] = None # Deprecating
128
+ worker_tolerations : Optional [List [V1Toleration ]] = None
123
129
appwrapper : bool = False
124
130
envs : Dict [str , str ] = field (default_factory = dict )
125
131
image : str = ""
@@ -272,7 +278,10 @@ def check_type(value, expected_type):
272
278
if origin_type is Union :
273
279
return any (check_type (value , union_type ) for union_type in args )
274
280
if origin_type is list :
275
- return all (check_type (elem , args [0 ]) for elem in value )
281
+ if value is not None :
282
+ return all (check_type (elem , args [0 ]) for elem in (value or []))
283
+ else :
284
+ return True
276
285
if origin_type is dict :
277
286
return all (
278
287
check_type (k , args [0 ]) and check_type (v , args [1 ])
0 commit comments