
Commit 830cbce

Authored Oct 23, 2023
Adding direct KubeRay compatibility to the SDK (#358)
* Added component generation
* Added multi-resource YAML support
* Cluster.up on ray cluster object
* Basic status and down for RayCluster
* Finished up/down and added unit tests
* Remove unused utils import
* Applied review feedback
* Changed naming of internal funcs
* Review feedback applied, auto-select
* OAuth conflict resolution
1 parent 2441f4f commit 830cbce
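The commit makes the MCAD AppWrapper path optional: `ClusterConfiguration` gains an `mcad: bool = True` field, and when it is set to `False` the SDK writes a multi-document YAML and creates the RayCluster, Route, and Secret objects directly against KubeRay. A minimal usage sketch of the new flag (the cluster name, namespace, and worker count below are illustrative placeholders, not taken from this commit):

```python
# Hypothetical usage of the direct-KubeRay path introduced by this commit.
# "demo-cluster" and "default" are placeholder values, not from the diff.
from codeflare_sdk.cluster.cluster import Cluster
from codeflare_sdk.cluster.config import ClusterConfiguration

config = ClusterConfiguration(
    name="demo-cluster",      # placeholder cluster name
    namespace="default",      # placeholder namespace
    num_workers=2,
    mcad=False,               # new flag: skip the MCAD AppWrapper entirely
)

cluster = Cluster(config)     # writes demo-cluster.yaml as a multi-document YAML
cluster.up()                  # creates the RayCluster/Route/Secret resources directly
status, ready = cluster.status()  # status is read from the RayCluster when mcad=False

# Job submission is unchanged apart from the rename Cluster.client -> Cluster.job_client:
# cluster.job_client.list_jobs()

cluster.down()                # deletes the same component resources
```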

File tree

7 files changed: +450 -84 lines changed

 

src/codeflare_sdk/cluster/cluster.py

+161 -61
@@ -70,7 +70,7 @@ def __init__(self, config: ClusterConfiguration):
         self.config = config
         self.app_wrapper_yaml = self.create_app_wrapper()
         self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
-        self._client = None
+        self._job_submission_client = None

     @property
     def _client_headers(self):
@@ -86,23 +86,25 @@ def _client_verify_tls(self):
         return not self.config.openshift_oauth

     @property
-    def client(self):
-        if self._client:
-            return self._client
+    def job_client(self):
+        if self._job_submission_client:
+            return self._job_submission_client
         if self.config.openshift_oauth:
             print(
                 api_config_handler().configuration.get_api_key_with_prefix(
                     "authorization"
                 )
             )
-            self._client = JobSubmissionClient(
+            self._job_submission_client = JobSubmissionClient(
                 self.cluster_dashboard_uri(),
                 headers=self._client_headers,
                 verify=self._client_verify_tls,
             )
         else:
-            self._client = JobSubmissionClient(self.cluster_dashboard_uri())
-        return self._client
+            self._job_submission_client = JobSubmissionClient(
+                self.cluster_dashboard_uri()
+            )
+        return self._job_submission_client

     def evaluate_dispatch_priority(self):
         priority_class = self.config.dispatch_priority
@@ -141,6 +143,10 @@ def create_app_wrapper(self):

         # Before attempting to create the cluster AW, let's evaluate the ClusterConfig
         if self.config.dispatch_priority:
+            if not self.config.mcad:
+                raise ValueError(
+                    "Invalid Cluster Configuration, cannot have dispatch priority without MCAD"
+                )
             priority_val = self.evaluate_dispatch_priority()
             if priority_val == None:
                 raise ValueError(
@@ -163,6 +169,7 @@ def create_app_wrapper(self):
         template = self.config.template
         image = self.config.image
         instascale = self.config.instascale
+        mcad = self.config.mcad
         instance_types = self.config.machine_types
         env = self.config.envs
         local_interactive = self.config.local_interactive
@@ -183,6 +190,7 @@ def create_app_wrapper(self):
             template=template,
             image=image,
             instascale=instascale,
+            mcad=mcad,
             instance_types=instance_types,
             env=env,
             local_interactive=local_interactive,
@@ -207,15 +215,18 @@ def up(self):
         try:
             config_check()
             api_instance = client.CustomObjectsApi(api_config_handler())
-            with open(self.app_wrapper_yaml) as f:
-                aw = yaml.load(f, Loader=yaml.FullLoader)
-            api_instance.create_namespaced_custom_object(
-                group="workload.codeflare.dev",
-                version="v1beta1",
-                namespace=namespace,
-                plural="appwrappers",
-                body=aw,
-            )
+            if self.config.mcad:
+                with open(self.app_wrapper_yaml) as f:
+                    aw = yaml.load(f, Loader=yaml.FullLoader)
+                api_instance.create_namespaced_custom_object(
+                    group="workload.codeflare.dev",
+                    version="v1beta1",
+                    namespace=namespace,
+                    plural="appwrappers",
+                    body=aw,
+                )
+            else:
+                self._component_resources_up(namespace, api_instance)
         except Exception as e:  # pragma: no cover
             return _kube_api_error_handling(e)

@@ -228,13 +239,16 @@ def down(self):
         try:
             config_check()
             api_instance = client.CustomObjectsApi(api_config_handler())
-            api_instance.delete_namespaced_custom_object(
-                group="workload.codeflare.dev",
-                version="v1beta1",
-                namespace=namespace,
-                plural="appwrappers",
-                name=self.app_wrapper_name,
-            )
+            if self.config.mcad:
+                api_instance.delete_namespaced_custom_object(
+                    group="workload.codeflare.dev",
+                    version="v1beta1",
+                    namespace=namespace,
+                    plural="appwrappers",
+                    name=self.app_wrapper_name,
+                )
+            else:
+                self._component_resources_down(namespace, api_instance)
         except Exception as e:  # pragma: no cover
             return _kube_api_error_handling(e)

@@ -252,42 +266,46 @@ def status(
         """
         ready = False
         status = CodeFlareClusterStatus.UNKNOWN
-        # check the app wrapper status
-        appwrapper = _app_wrapper_status(self.config.name, self.config.namespace)
-        if appwrapper:
-            if appwrapper.status in [
-                AppWrapperStatus.RUNNING,
-                AppWrapperStatus.COMPLETED,
-                AppWrapperStatus.RUNNING_HOLD_COMPLETION,
-            ]:
-                ready = False
-                status = CodeFlareClusterStatus.STARTING
-            elif appwrapper.status in [
-                AppWrapperStatus.FAILED,
-                AppWrapperStatus.DELETED,
-            ]:
-                ready = False
-                status = CodeFlareClusterStatus.FAILED  # should deleted be separate
-                return status, ready  # exit early, no need to check ray status
-            elif appwrapper.status in [
-                AppWrapperStatus.PENDING,
-                AppWrapperStatus.QUEUEING,
-            ]:
-                ready = False
-                if appwrapper.status == AppWrapperStatus.PENDING:
-                    status = CodeFlareClusterStatus.QUEUED
-                else:
-                    status = CodeFlareClusterStatus.QUEUEING
-                if print_to_console:
-                    pretty_print.print_app_wrappers_status([appwrapper])
-                return (
-                    status,
-                    ready,
-                )  # no need to check the ray status since still in queue
+        if self.config.mcad:
+            # check the app wrapper status
+            appwrapper = _app_wrapper_status(self.config.name, self.config.namespace)
+            if appwrapper:
+                if appwrapper.status in [
+                    AppWrapperStatus.RUNNING,
+                    AppWrapperStatus.COMPLETED,
+                    AppWrapperStatus.RUNNING_HOLD_COMPLETION,
+                ]:
+                    ready = False
+                    status = CodeFlareClusterStatus.STARTING
+                elif appwrapper.status in [
+                    AppWrapperStatus.FAILED,
+                    AppWrapperStatus.DELETED,
+                ]:
+                    ready = False
+                    status = CodeFlareClusterStatus.FAILED  # should deleted be separate
+                    return status, ready  # exit early, no need to check ray status
+                elif appwrapper.status in [
+                    AppWrapperStatus.PENDING,
+                    AppWrapperStatus.QUEUEING,
+                ]:
+                    ready = False
+                    if appwrapper.status == AppWrapperStatus.PENDING:
+                        status = CodeFlareClusterStatus.QUEUED
+                    else:
+                        status = CodeFlareClusterStatus.QUEUEING
+                    if print_to_console:
+                        pretty_print.print_app_wrappers_status([appwrapper])
+                    return (
+                        status,
+                        ready,
+                    )  # no need to check the ray status since still in queue

         # check the ray cluster status
         cluster = _ray_cluster_status(self.config.name, self.config.namespace)
-        if cluster and not cluster.status == RayClusterStatus.UNKNOWN:
+        if cluster:
+            if cluster.status == RayClusterStatus.UNKNOWN:
+                ready = False
+                status = CodeFlareClusterStatus.STARTING
             if cluster.status == RayClusterStatus.READY:
                 ready = True
                 status = CodeFlareClusterStatus.READY
@@ -407,19 +425,19 @@ def list_jobs(self) -> List:
         """
         This method accesses the head ray node in your cluster and lists the running jobs.
         """
-        return self.client.list_jobs()
+        return self.job_client.list_jobs()

     def job_status(self, job_id: str) -> str:
         """
         This method accesses the head ray node in your cluster and returns the job status for the provided job id.
         """
-        return self.client.get_job_status(job_id)
+        return self.job_client.get_job_status(job_id)

     def job_logs(self, job_id: str) -> str:
         """
         This method accesses the head ray node in your cluster and returns the logs for the provided job id.
         """
-        return self.client.get_job_logs(job_id)
+        return self.job_client.get_job_logs(job_id)

     def torchx_config(
         self, working_dir: str = None, requirements: str = None
@@ -435,7 +453,7 @@ def torchx_config(
             to_return["requirements"] = requirements
         return to_return

-    def from_k8_cluster_object(rc):
+    def from_k8_cluster_object(rc, mcad=True):
         machine_types = (
             rc["metadata"]["labels"]["orderedinstance"].split("_")
             if "orderedinstance" in rc["metadata"]["labels"]
@@ -474,6 +492,7 @@ def from_k8_cluster_object(rc):
                 0
             ]["image"],
             local_interactive=local_interactive,
+            mcad=mcad,
         )
         return Cluster(cluster_config)

@@ -484,6 +503,66 @@ def local_client_url(self):
         else:
             return "None"

+    def _component_resources_up(
+        self, namespace: str, api_instance: client.CustomObjectsApi
+    ):
+        with open(self.app_wrapper_yaml) as f:
+            yamls = yaml.load_all(f, Loader=yaml.FullLoader)
+            for resource in yamls:
+                if resource["kind"] == "RayCluster":
+                    api_instance.create_namespaced_custom_object(
+                        group="ray.io",
+                        version="v1alpha1",
+                        namespace=namespace,
+                        plural="rayclusters",
+                        body=resource,
+                    )
+                elif resource["kind"] == "Route":
+                    api_instance.create_namespaced_custom_object(
+                        group="route.openshift.io",
+                        version="v1",
+                        namespace=namespace,
+                        plural="routes",
+                        body=resource,
+                    )
+                elif resource["kind"] == "Secret":
+                    secret_instance = client.CoreV1Api(api_config_handler())
+                    secret_instance.create_namespaced_secret(
+                        namespace=namespace,
+                        body=resource,
+                    )
+
+    def _component_resources_down(
+        self, namespace: str, api_instance: client.CustomObjectsApi
+    ):
+        with open(self.app_wrapper_yaml) as f:
+            yamls = yaml.load_all(f, Loader=yaml.FullLoader)
+            for resource in yamls:
+                if resource["kind"] == "RayCluster":
+                    api_instance.delete_namespaced_custom_object(
+                        group="ray.io",
+                        version="v1alpha1",
+                        namespace=namespace,
+                        plural="rayclusters",
+                        name=self.app_wrapper_name,
+                    )
+                elif resource["kind"] == "Route":
+                    name = resource["metadata"]["name"]
+                    api_instance.delete_namespaced_custom_object(
+                        group="route.openshift.io",
+                        version="v1",
+                        namespace=namespace,
+                        plural="routes",
+                        name=name,
+                    )
+                elif resource["kind"] == "Secret":
+                    name = resource["metadata"]["name"]
+                    secret_instance = client.CoreV1Api(api_config_handler())
+                    secret_instance.delete_namespaced_secret(
+                        namespace=namespace,
+                        name=name,
+                    )
+

 def list_all_clusters(namespace: str, print_to_console: bool = True):
     """
@@ -549,13 +628,33 @@ def get_cluster(cluster_name: str, namespace: str = "default"):

     for rc in rcs["items"]:
         if rc["metadata"]["name"] == cluster_name:
-            return Cluster.from_k8_cluster_object(rc)
+            mcad = _check_aw_exists(cluster_name, namespace)
+            return Cluster.from_k8_cluster_object(rc, mcad=mcad)
     raise FileNotFoundError(
         f"Cluster {cluster_name} is not found in {namespace} namespace"
     )


 # private methods
+def _check_aw_exists(name: str, namespace: str) -> bool:
+    try:
+        config_check()
+        api_instance = client.CustomObjectsApi(api_config_handler())
+        aws = api_instance.list_namespaced_custom_object(
+            group="workload.codeflare.dev",
+            version="v1beta1",
+            namespace=namespace,
+            plural="appwrappers",
+        )
+    except Exception as e:  # pragma: no cover
+        return _kube_api_error_handling(e, print_error=False)
+
+    for aw in aws["items"]:
+        if aw["metadata"]["name"] == name:
+            return True
+    return False
+
+
 def _get_ingress_domain():
     try:
         config_check()
@@ -660,6 +759,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:

     config_check()
     api_instance = client.CustomObjectsApi(api_config_handler())
+    # UPDATE THIS
     routes = api_instance.list_namespaced_custom_object(
         group="route.openshift.io",
         version="v1",
src/codeflare_sdk/cluster/config.py

+1
@@ -46,6 +46,7 @@ class ClusterConfiguration:
     num_gpus: int = 0
     template: str = f"{dir}/templates/base-template.yaml"
     instascale: bool = False
+    mcad: bool = True
     envs: dict = field(default_factory=dict)
     image: str = "quay.io/project-codeflare/ray:latest-py39-cu118"
     local_interactive: bool = False
src/codeflare_sdk/job/jobs.py

+2 -5
@@ -22,9 +22,6 @@
 from torchx.schedulers.ray_scheduler import RayScheduler
 from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo

-from ray.job_submission import JobSubmissionClient
-
-import openshift as oc

 if TYPE_CHECKING:
     from ..cluster.cluster import Cluster
@@ -96,9 +93,9 @@ def __init__(

     def _dry_run(self, cluster: "Cluster"):
         j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}"  # # of proc. = # of gpus
-        runner = get_runner(ray_client=cluster.client)
+        runner = get_runner(ray_client=cluster.job_client)
         runner._scheduler_instances["ray"] = RayScheduler(
-            session_name=runner._name, ray_client=cluster.client
+            session_name=runner._name, ray_client=cluster.job_client
         )
         return (
             runner.dryrun(
src/codeflare_sdk/utils/generate_yaml.py

+18 -1
@@ -457,6 +457,19 @@ def _create_oauth_sidecar_object(
     )


+def write_components(user_yaml: dict, output_file_name: str):
+    components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
+    open(output_file_name, "w").close()
+    with open(output_file_name, "a") as outfile:
+        for component in components:
+            if "generictemplate" in component:
+                outfile.write("---\n")
+                yaml.dump(
+                    component["generictemplate"], outfile, default_flow_style=False
+                )
+    print(f"Written to: {output_file_name}")
+
+
 def generate_appwrapper(
     name: str,
     namespace: str,
@@ -472,6 +485,7 @@
     template: str,
     image: str,
     instascale: bool,
+    mcad: bool,
     instance_types: list,
     env,
     local_interactive: bool,
@@ -527,5 +541,8 @@
         enable_openshift_oauth(user_yaml, cluster_name, namespace)

     outfile = appwrapper_name + ".yaml"
-    write_user_appwrapper(user_yaml, outfile)
+    if not mcad:
+        write_components(user_yaml, outfile)
+    else:
+        write_user_appwrapper(user_yaml, outfile)
     return outfile
src/codeflare_sdk/utils/kube_api_helpers.py

+7 -3
@@ -23,7 +23,9 @@


 # private methods
-def _kube_api_error_handling(e: Exception):  # pragma: no cover
+def _kube_api_error_handling(
+    e: Exception, print_error: bool = True
+):  # pragma: no cover
     perm_msg = (
         "Action not permitted, have you put in correct/up-to-date auth credentials?"
     )
@@ -32,11 +34,13 @@ def _kube_api_error_handling(e: Exception):  # pragma: no cover
     if type(e) == config.ConfigException:
         raise PermissionError(perm_msg)
     if type(e) == executing.executing.NotOneValueFound:
-        print(nf_msg)
+        if print_error:
+            print(nf_msg)
         return
     if type(e) == client.ApiException:
         if e.reason == "Not Found":
-            print(nf_msg)
+            if print_error:
+                print(nf_msg)
             return
         elif e.reason == "Unauthorized" or e.reason == "Forbidden":
             raise PermissionError(perm_msg)

tests/test-case-no-mcad.yamls

+162
@@ -0,0 +1,162 @@
+---
+apiVersion: ray.io/v1alpha1
+kind: RayCluster
+metadata:
+  labels:
+    appwrapper.mcad.ibm.com: unit-test-cluster-ray
+    controller-tools.k8s.io: '1.0'
+  name: unit-test-cluster-ray
+  namespace: ns
+spec:
+  autoscalerOptions:
+    idleTimeoutSeconds: 60
+    imagePullPolicy: Always
+    resources:
+      limits:
+        cpu: 500m
+        memory: 512Mi
+      requests:
+        cpu: 500m
+        memory: 512Mi
+    upscalingMode: Default
+  enableInTreeAutoscaling: false
+  headGroupSpec:
+    rayStartParams:
+      block: 'true'
+      dashboard-host: 0.0.0.0
+      num-gpus: '0'
+    serviceType: ClusterIP
+    template:
+      spec:
+        affinity:
+          nodeAffinity:
+            requiredDuringSchedulingIgnoredDuringExecution:
+              nodeSelectorTerms:
+              - matchExpressions:
+                - key: unit-test-cluster-ray
+                  operator: In
+                  values:
+                  - unit-test-cluster-ray
+        containers:
+        - env:
+          - name: MY_POD_IP
+            valueFrom:
+              fieldRef:
+                fieldPath: status.podIP
+          - name: RAY_USE_TLS
+            value: '0'
+          - name: RAY_TLS_SERVER_CERT
+            value: /home/ray/workspace/tls/server.crt
+          - name: RAY_TLS_SERVER_KEY
+            value: /home/ray/workspace/tls/server.key
+          - name: RAY_TLS_CA_CERT
+            value: /home/ray/workspace/tls/ca.crt
+          image: quay.io/project-codeflare/ray:latest-py39-cu118
+          imagePullPolicy: Always
+          lifecycle:
+            preStop:
+              exec:
+                command:
+                - /bin/sh
+                - -c
+                - ray stop
+          name: ray-head
+          ports:
+          - containerPort: 6379
+            name: gcs
+          - containerPort: 8265
+            name: dashboard
+          - containerPort: 10001
+            name: client
+          resources:
+            limits:
+              cpu: 2
+              memory: 8G
+              nvidia.com/gpu: 0
+            requests:
+              cpu: 2
+              memory: 8G
+              nvidia.com/gpu: 0
+        imagePullSecrets:
+        - name: unit-test-pull-secret
+  rayVersion: 2.7.0
+  workerGroupSpecs:
+  - groupName: small-group-unit-test-cluster-ray
+    maxReplicas: 2
+    minReplicas: 2
+    rayStartParams:
+      block: 'true'
+      num-gpus: '7'
+    replicas: 2
+    template:
+      metadata:
+        annotations:
+          key: value
+        labels:
+          key: value
+      spec:
+        affinity:
+          nodeAffinity:
+            requiredDuringSchedulingIgnoredDuringExecution:
+              nodeSelectorTerms:
+              - matchExpressions:
+                - key: unit-test-cluster-ray
+                  operator: In
+                  values:
+                  - unit-test-cluster-ray
+        containers:
+        - env:
+          - name: MY_POD_IP
+            valueFrom:
+              fieldRef:
+                fieldPath: status.podIP
+          - name: RAY_USE_TLS
+            value: '0'
+          - name: RAY_TLS_SERVER_CERT
+            value: /home/ray/workspace/tls/server.crt
+          - name: RAY_TLS_SERVER_KEY
+            value: /home/ray/workspace/tls/server.key
+          - name: RAY_TLS_CA_CERT
+            value: /home/ray/workspace/tls/ca.crt
+          image: quay.io/project-codeflare/ray:latest-py39-cu118
+          lifecycle:
+            preStop:
+              exec:
+                command:
+                - /bin/sh
+                - -c
+                - ray stop
+          name: machine-learning
+          resources:
+            limits:
+              cpu: 4
+              memory: 6G
+              nvidia.com/gpu: 7
+            requests:
+              cpu: 3
+              memory: 5G
+              nvidia.com/gpu: 7
+        imagePullSecrets:
+        - name: unit-test-pull-secret
+        initContainers:
+        - command:
+          - sh
+          - -c
+          - until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local;
+            do echo waiting for myservice; sleep 2; done
+          image: busybox:1.28
+          name: init-myservice
+---
+apiVersion: route.openshift.io/v1
+kind: Route
+metadata:
+  labels:
+    odh-ray-cluster-service: unit-test-cluster-ray-head-svc
+  name: ray-dashboard-unit-test-cluster-ray
+  namespace: ns
+spec:
+  port:
+    targetPort: dashboard
+  to:
+    kind: Service
+    name: unit-test-cluster-ray-head-svc

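The fixture above is the multi-document YAML that `write_components` emits when `mcad=False`; `_component_resources_up` and `_component_resources_down` simply iterate its documents and dispatch on `kind`. A small standalone sketch of that iteration pattern (the `handle` function is a stand-in for the Kubernetes API calls made in `cluster.py`):

```python
# Sketch of iterating a multi-document component file the way the SDK does.
# "handle" is a placeholder for create/delete calls via the Kubernetes client.
import yaml

def handle(resource: dict):
    # stand-in for CustomObjectsApi / CoreV1Api calls
    print(f"would apply {resource['kind']} '{resource['metadata']['name']}'")

with open("tests/test-case-no-mcad.yamls") as f:
    for resource in yaml.load_all(f, Loader=yaml.FullLoader):
        if resource["kind"] == "RayCluster":
            handle(resource)   # ray.io/v1alpha1, plural "rayclusters"
        elif resource["kind"] == "Route":
            handle(resource)   # route.openshift.io/v1, plural "routes"
        elif resource["kind"] == "Secret":
            handle(resource)   # created via CoreV1Api rather than CustomObjectsApi
```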
tests/unit_test.py

+99 -14
@@ -34,6 +34,7 @@
     get_cluster,
     _app_wrapper_status,
     _ray_cluster_status,
+    _get_ingress_domain,
 )
 from codeflare_sdk.cluster.auth import (
     TokenAuthentication,
@@ -242,6 +243,8 @@ def test_config_creation():
     assert config.machine_types == ["cpu.small", "gpu.large"]
     assert config.image_pull_secrets == ["unit-test-pull-secret"]
     assert config.dispatch_priority == None
+    assert config.mcad == True
+    assert config.local_interactive == False


 def test_cluster_creation():
@@ -253,6 +256,20 @@ def test_cluster_creation():
     )


+def test_cluster_creation_no_mcad():
+    config = createClusterConfig()
+    config.name = "unit-test-cluster-ray"
+    config.mcad = False
+    cluster = Cluster(config)
+    assert cluster.app_wrapper_yaml == "unit-test-cluster-ray.yaml"
+    assert cluster.app_wrapper_name == "unit-test-cluster-ray"
+    assert filecmp.cmp(
+        "unit-test-cluster-ray.yaml",
+        f"{parent}/tests/test-case-no-mcad.yamls",
+        shallow=True,
+    )
+
+
 def test_cluster_creation_priority(mocker):
     mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
     mocker.patch(
@@ -286,23 +303,49 @@ def test_default_cluster_creation(mocker):


 def arg_check_apply_effect(group, version, namespace, plural, body, *args):
-    assert group == "workload.codeflare.dev"
-    assert version == "v1beta1"
     assert namespace == "ns"
-    assert plural == "appwrappers"
-    with open("unit-test-cluster.yaml") as f:
-        aw = yaml.load(f, Loader=yaml.FullLoader)
-    assert body == aw
     assert args == tuple()
+    if plural == "appwrappers":
+        assert group == "workload.codeflare.dev"
+        assert version == "v1beta1"
+        with open("unit-test-cluster.yaml") as f:
+            aw = yaml.load(f, Loader=yaml.FullLoader)
+        assert body == aw
+    elif plural == "rayclusters":
+        assert group == "ray.io"
+        assert version == "v1alpha1"
+        with open("unit-test-cluster-ray.yaml") as f:
+            yamls = yaml.load_all(f, Loader=yaml.FullLoader)
+            for resource in yamls:
+                if resource["kind"] == "RayCluster":
+                    assert body == resource
+    elif plural == "routes":
+        assert group == "route.openshift.io"
+        assert version == "v1"
+        with open("unit-test-cluster-ray.yaml") as f:
+            yamls = yaml.load_all(f, Loader=yaml.FullLoader)
+            for resource in yamls:
+                if resource["kind"] == "Route":
+                    assert body == resource
+    else:
+        assert 1 == 0


 def arg_check_del_effect(group, version, namespace, plural, name, *args):
-    assert group == "workload.codeflare.dev"
-    assert version == "v1beta1"
     assert namespace == "ns"
-    assert plural == "appwrappers"
-    assert name == "unit-test-cluster"
     assert args == tuple()
+    if plural == "appwrappers":
+        assert group == "workload.codeflare.dev"
+        assert version == "v1beta1"
+        assert name == "unit-test-cluster"
+    elif plural == "rayclusters":
+        assert group == "ray.io"
+        assert version == "v1alpha1"
+        assert name == "unit-test-cluster-ray"
+    elif plural == "routes":
+        assert group == "route.openshift.io"
+        assert version == "v1"
+        assert name == "ray-dashboard-unit-test-cluster-ray"


 def test_cluster_up_down(mocker):
@@ -324,6 +367,47 @@ def test_cluster_up_down(mocker):
     cluster.down()


+def test_cluster_up_down_no_mcad(mocker):
+    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
+    mocker.patch(
+        "kubernetes.client.CustomObjectsApi.create_namespaced_custom_object",
+        side_effect=arg_check_apply_effect,
+    )
+    mocker.patch(
+        "kubernetes.client.CustomObjectsApi.delete_namespaced_custom_object",
+        side_effect=arg_check_del_effect,
+    )
+    mocker.patch(
+        "kubernetes.client.CustomObjectsApi.list_cluster_custom_object",
+        return_value={"items": []},
+    )
+    config = createClusterConfig()
+    config.name = "unit-test-cluster-ray"
+    config.mcad = False
+    cluster = Cluster(config)
+    cluster.up()
+    cluster.down()
+
+
+def arg_check_list_effect(group, version, plural, name, *args):
+    assert group == "config.openshift.io"
+    assert version == "v1"
+    assert plural == "ingresses"
+    assert name == "cluster"
+    assert args == tuple()
+    return {"spec": {"domain": "test"}}
+
+
+def test_get_ingress_domain(mocker):
+    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
+    mocker.patch(
+        "kubernetes.client.CustomObjectsApi.get_cluster_custom_object",
+        side_effect=arg_check_list_effect,
+    )
+    domain = _get_ingress_domain()
+    assert domain == "test"
+
+
 def aw_status_fields(group, version, namespace, plural, *args):
     assert group == "workload.codeflare.dev"
     assert version == "v1beta1"
@@ -1851,7 +1935,7 @@ def test_DDPJobDefinition_dry_run(mocker: MockerFixture):
         "codeflare_sdk.cluster.cluster.Cluster.cluster_dashboard_uri",
         return_value="",
     )
-    mocker.patch.object(Cluster, "client")
+    mocker.patch.object(Cluster, "job_client")
     ddp = createTestDDP()
     cluster = createClusterWithConfig()
     ddp_job, _ = ddp._dry_run(cluster)
@@ -1921,7 +2005,7 @@ def test_DDPJobDefinition_dry_run_no_resource_args(mocker):
     Test that the dry run correctly gets resources from the cluster object
     when the job definition does not specify resources.
     """
-    mocker.patch.object(Cluster, "client")
+    mocker.patch.object(Cluster, "job_client")
     mocker.patch(
         "codeflare_sdk.cluster.cluster.Cluster.cluster_dashboard_uri",
         return_value="",
@@ -2013,7 +2097,7 @@ def test_DDPJobDefinition_submit(mocker: MockerFixture):
     mock_schedule = MagicMock()
     mocker.patch.object(Runner, "schedule", mock_schedule)
     mock_schedule.return_value = "fake-dashboard-url"
-    mocker.patch.object(Cluster, "client")
+    mocker.patch.object(Cluster, "job_client")
     ddp_def = createTestDDP()
     cluster = createClusterWithConfig()
     mocker.patch(
@@ -2040,7 +2124,7 @@ def test_DDPJobDefinition_submit(mocker: MockerFixture):


 def test_DDPJob_creation(mocker: MockerFixture):
-    mocker.patch.object(Cluster, "client")
+    mocker.patch.object(Cluster, "job_client")
     mock_schedule = MagicMock()
     mocker.patch.object(Runner, "schedule", mock_schedule)
     mocker.patch.object(
@@ -2432,6 +2516,7 @@ def test_cleanup():
     os.remove("unit-test-cluster.yaml")
     os.remove("prio-test-cluster.yaml")
     os.remove("unit-test-default-cluster.yaml")
+    os.remove("unit-test-cluster-ray.yaml")
     os.remove("test.yaml")
     os.remove("raytest2.yaml")
     os.remove("quicktest.yaml")