Skip to content

Commit 92ab6e4

Browse files
committed
fix: cluster configuration validation for bool type
As bool is a subtype of int, True/False was considered as 1/0
1 parent e666e0a commit 92ab6e4

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

src/codeflare_sdk/common/utils/unit_test_support.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def createClusterWrongType():
5555
config = ClusterConfiguration(
5656
name="unit-test-cluster",
5757
namespace="ns",
58-
num_workers=2,
58+
num_workers=True,
5959
worker_cpu_requests=[],
6060
worker_cpu_limits=4,
6161
worker_memory_requests=5,

src/codeflare_sdk/ray/cluster/config.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,15 @@ def _memory_to_resource(self):
242242

243243
def _validate_types(self):
244244
"""Validate the types of all fields in the ClusterConfiguration dataclass."""
245+
errors = []
245246
for field_info in fields(self):
246247
value = getattr(self, field_info.name)
247248
expected_type = field_info.type
248249
if not self._is_type(value, expected_type):
249-
raise TypeError(
250-
f"'{field_info.name}' should be of type {expected_type}"
251-
)
250+
errors.append(f"'{field_info.name}' should be of type {expected_type}.")
251+
252+
if errors:
253+
raise TypeError("Type validation failed:\n" + "\n".join(errors))
252254

253255
@staticmethod
254256
def _is_type(value, expected_type):
@@ -268,6 +270,10 @@ def check_type(value, expected_type):
268270
)
269271
if origin_type is tuple:
270272
return all(check_type(elem, etype) for elem, etype in zip(value, args))
273+
if expected_type is int:
274+
return isinstance(value, int) and not isinstance(value, bool)
275+
if expected_type is bool:
276+
return isinstance(value, bool)
271277
return isinstance(value, expected_type)
272278

273279
return check_type(value, expected_type)

src/codeflare_sdk/ray/cluster/test_config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,11 @@ def test_all_config_params_aw(mocker):
108108

109109

110110
def test_config_creation_wrong_type():
111-
with pytest.raises(TypeError):
111+
with pytest.raises(TypeError) as error_info:
112112
createClusterWrongType()
113113

114+
assert len(str(error_info.value).splitlines()) == 4
115+
114116

115117
def test_cluster_config_deprecation_conversion(mocker):
116118
config = ClusterConfiguration(

0 commit comments

Comments
 (0)