Skip to content

Commit 604c162

Browse files
committed
Check Pod Labels first and add unit tests
Signed-off-by: Ryan O'Leary <[email protected]>
1 parent e9bd3e1 commit 604c162

File tree

2 files changed

+263
-59
lines changed

2 files changed

+263
-59
lines changed

ray-operator/controllers/ray/common/pod.go

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,34 +1062,68 @@ func addDefaultRayNodeLabels(pod *corev1.Pod) {
10621062
})
10631063
}
10641064
if !containsEnvVar(*rayContainer, utils.RayNodeZone) {
1065-
// uses downward api to set the ray.io/availability-zone node label
1066-
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesiozone
1067-
envVars = append(envVars, corev1.EnvVar{
1068-
Name: utils.RayNodeZone,
1069-
ValueFrom: &corev1.EnvVarSource{
1070-
FieldRef: &corev1.ObjectFieldSelector{
1071-
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyZoneLabel),
1072-
},
1073-
},
1074-
})
1065+
envVars = append(envVars, getPodZoneEnvVar(pod))
10751066
}
10761067
if !containsEnvVar(*rayContainer, utils.RayNodeRegion) {
1077-
// uses downward api to set the ray.io/availability-region node label
1078-
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesioregion
1079-
envVars = append(envVars, corev1.EnvVar{
1080-
Name: utils.RayNodeRegion,
1081-
ValueFrom: &corev1.EnvVarSource{
1082-
FieldRef: &corev1.ObjectFieldSelector{
1083-
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyRegionLabel),
1084-
},
1085-
},
1086-
})
1068+
envVars = append(envVars, getPodRegionEnvVar(pod))
10871069
}
10881070
rayContainer.Env = envVars
10891071
}
10901072

1091-
// getPodMarketTypeFromNodeSelector is a helper function to determine the ray.io/market-type label
1092-
// based on a Kubernetes Pod spec.
1073+
// getPodZoneEnvVar is a helper function to determine the ray.io/availability-zone label value
1074+
// based on a Pod spec - checking labels, nodeSelectors, and then falling back to downward API.
1075+
func getPodZoneEnvVar(pod *corev1.Pod) corev1.EnvVar {
1076+
if podZone, ok := pod.Labels[utils.K8sTopologyZoneLabel]; ok && podZone != "" {
1077+
return corev1.EnvVar{
1078+
Name: utils.RayNodeZone,
1079+
Value: podZone,
1080+
}
1081+
} else if podZone, ok := pod.Spec.NodeSelector[utils.K8sTopologyZoneLabel]; ok && podZone != "" {
1082+
return corev1.EnvVar{
1083+
Name: utils.RayNodeZone,
1084+
Value: podZone,
1085+
}
1086+
}
1087+
// uses downward api to set the ray.io/availability-zone node label
1088+
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesiozone
1089+
return corev1.EnvVar{
1090+
Name: utils.RayNodeZone,
1091+
ValueFrom: &corev1.EnvVarSource{
1092+
FieldRef: &corev1.ObjectFieldSelector{
1093+
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyZoneLabel),
1094+
},
1095+
},
1096+
}
1097+
}
1098+
1099+
// getPodRegionEnvVar is a helper function to determine the ray.io/availability-region label value
1100+
// based on a Pod spec - checking labels, nodeSelectors, and then falling back to downward API.
1101+
func getPodRegionEnvVar(pod *corev1.Pod) corev1.EnvVar {
1102+
if podRegion, ok := pod.Labels[utils.K8sTopologyRegionLabel]; ok && podRegion != "" {
1103+
return corev1.EnvVar{
1104+
Name: utils.RayNodeRegion,
1105+
Value: podRegion,
1106+
}
1107+
} else if podRegion, ok := pod.Spec.NodeSelector[utils.K8sTopologyRegionLabel]; ok && podRegion != "" {
1108+
return corev1.EnvVar{
1109+
Name: utils.RayNodeRegion,
1110+
Value: podRegion,
1111+
}
1112+
}
1113+
// uses downward api to set the ray.io/availability-region node label
1114+
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesioregion
1115+
return corev1.EnvVar{
1116+
Name: utils.RayNodeRegion,
1117+
ValueFrom: &corev1.EnvVarSource{
1118+
FieldRef: &corev1.ObjectFieldSelector{
1119+
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyRegionLabel),
1120+
},
1121+
},
1122+
}
1123+
}
1124+
1125+
// getPodMarketType is a helper function to determine the ray.io/market-type
1126+
// label value based on a Kubernetes Pod spec - checking labels, nodeSelector, and nodeAffinity.
10931127
func getPodMarketType(pod *corev1.Pod) utils.PodMarketType {
10941128
marketType := getPodMarketTypeFromNodeSelector(pod.Spec.NodeSelector)
10951129

ray-operator/controllers/ray/common/pod_test.go

Lines changed: 207 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2111,54 +2111,224 @@ func TestGetPodMarketType(t *testing.T) {
21112111
}
21122112
}
21132113

2114-
func TestAddDefaultRayNodeLabels_GKESpot(t *testing.T) {
2115-
pod := &corev1.Pod{
2116-
ObjectMeta: metav1.ObjectMeta{
2117-
Labels: map[string]string{
2118-
"ray.io/group": "test-worker-group-1",
2119-
"topology.kubernetes.io/region": "us-central2",
2120-
"topology.kubernetes.io/zone": "us-central2-b",
2114+
func TestAddDefaultRayNodeLabels(t *testing.T) {
2115+
tests := []struct {
2116+
labels map[string]string
2117+
nodeSelector map[string]string
2118+
nodeAffinity *corev1.NodeAffinity
2119+
expectedEnv map[string]string
2120+
name string
2121+
}{
2122+
{
2123+
name: "Availability zone vars set from region and zone topology labels",
2124+
labels: map[string]string{
2125+
utils.K8sTopologyRegionLabel: "us-west4",
2126+
utils.K8sTopologyZoneLabel: "us-west4-a",
2127+
},
2128+
expectedEnv: map[string]string{
2129+
utils.RayNodeRegion: "us-west4",
2130+
utils.RayNodeZone: "us-west4-a",
21212131
},
21222132
},
2123-
Spec: corev1.PodSpec{
2124-
Containers: []corev1.Container{
2125-
{Name: "ray"},
2133+
{
2134+
name: "Availability zone vars set from region and zone topology nodeSelectors",
2135+
nodeSelector: map[string]string{
2136+
utils.K8sTopologyRegionLabel: "us-central2",
2137+
utils.K8sTopologyZoneLabel: "us-central2-b",
21262138
},
2127-
NodeSelector: map[string]string{
2128-
"cloud.google.com/gke-spot": "true",
2139+
expectedEnv: map[string]string{
2140+
utils.RayNodeRegion: "us-central2",
2141+
utils.RayNodeZone: "us-central2-b",
2142+
},
2143+
},
2144+
{
2145+
name: "Availability zone vars set from downward API",
2146+
expectedEnv: map[string]string{
2147+
utils.RayNodeRegion: "metadata.labels['topology.kubernetes.io/region']",
2148+
utils.RayNodeZone: "metadata.labels['topology.kubernetes.io/zone']",
2149+
},
2150+
},
2151+
{
2152+
name: "Market type env var set from GKE Spot nodeSelector",
2153+
nodeSelector: map[string]string{
2154+
utils.GKESpotLabel: "true",
2155+
utils.K8sTopologyRegionLabel: "me-central1",
2156+
utils.K8sTopologyZoneLabel: "me-central1-a",
2157+
},
2158+
expectedEnv: map[string]string{
2159+
utils.RayNodeMarketType: string(utils.SpotMarketType),
2160+
utils.RayNodeRegion: "me-central1",
2161+
utils.RayNodeZone: "me-central1-a",
2162+
},
2163+
},
2164+
{
2165+
name: "Market type env var set from EKS Spot nodeSelector",
2166+
nodeSelector: map[string]string{
2167+
utils.EKSCapacityTypeLabel: "SPOT",
2168+
utils.K8sTopologyRegionLabel: "us-central1",
2169+
utils.K8sTopologyZoneLabel: "us-central1-c",
2170+
},
2171+
expectedEnv: map[string]string{
2172+
utils.RayNodeMarketType: string(utils.SpotMarketType),
2173+
utils.RayNodeRegion: "us-central1",
2174+
utils.RayNodeZone: "us-central1-c",
2175+
},
2176+
},
2177+
{
2178+
name: "Market type env var set from nodeAffinity",
2179+
nodeAffinity: &corev1.NodeAffinity{
2180+
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
2181+
NodeSelectorTerms: []corev1.NodeSelectorTerm{
2182+
{
2183+
MatchExpressions: []corev1.NodeSelectorRequirement{
2184+
{
2185+
Key: utils.EKSCapacityTypeLabel,
2186+
Operator: corev1.NodeSelectorOpIn,
2187+
Values: []string{"SPOT"},
2188+
},
2189+
},
2190+
},
2191+
},
2192+
},
2193+
},
2194+
expectedEnv: map[string]string{
2195+
utils.RayNodeMarketType: string(utils.SpotMarketType),
2196+
utils.RayNodeRegion: "metadata.labels['topology.kubernetes.io/region']",
2197+
utils.RayNodeZone: "metadata.labels['topology.kubernetes.io/zone']",
21292198
},
21302199
},
21312200
}
21322201

2133-
addDefaultRayNodeLabels(pod)
2134-
rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
2135-
checkContainerEnv(t, rayContainer, "RAY_NODE_MARKET_TYPE", "spot")
2136-
checkContainerEnv(t, rayContainer, "RAY_NODE_REGION", "metadata.labels['topology.kubernetes.io/region']")
2137-
checkContainerEnv(t, rayContainer, "RAY_NODE_ZONE", "metadata.labels['topology.kubernetes.io/zone']")
2202+
for _, tt := range tests {
2203+
t.Run(tt.name, func(t *testing.T) {
2204+
pod := &corev1.Pod{
2205+
ObjectMeta: metav1.ObjectMeta{
2206+
Labels: tt.labels,
2207+
},
2208+
Spec: corev1.PodSpec{
2209+
Containers: []corev1.Container{{Name: "ray"}},
2210+
NodeSelector: tt.nodeSelector,
2211+
},
2212+
}
2213+
if tt.nodeAffinity != nil {
2214+
pod.Spec.Affinity = &corev1.Affinity{NodeAffinity: tt.nodeAffinity}
2215+
}
2216+
// validate default labels are set correctly from Pod spec as env vars
2217+
addDefaultRayNodeLabels(pod)
2218+
rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
2219+
for key, expectedVar := range tt.expectedEnv {
2220+
foundVar := false
2221+
for _, env := range rayContainer.Env {
2222+
if env.Name == key {
2223+
if env.Value != "" {
2224+
if env.Value != expectedVar {
2225+
t.Errorf("%s: got value %q, but expected %q", key, env.Value, expectedVar)
2226+
}
2227+
} else if env.ValueFrom != nil && env.ValueFrom.FieldRef != nil {
2228+
if env.ValueFrom.FieldRef.FieldPath != expectedVar {
2229+
t.Errorf("%s: got FieldPath %q, but expected %q", key, env.ValueFrom.FieldRef.FieldPath, expectedVar)
2230+
}
2231+
} else {
2232+
t.Errorf("%s: environment var not set as expected", key)
2233+
}
2234+
foundVar = true
2235+
break
2236+
}
2237+
}
2238+
if !foundVar {
2239+
t.Errorf("%s: not found in container env", key)
2240+
}
2241+
}
2242+
})
2243+
}
21382244
}
21392245

2140-
func TestAddDefaultRayNodeLabels_EKSSpot(t *testing.T) {
2141-
pod := &corev1.Pod{
2142-
ObjectMeta: metav1.ObjectMeta{
2143-
Labels: map[string]string{
2144-
"ray.io/group": "test-worker-group-2",
2145-
"topology.kubernetes.io/region": "us-west4",
2146-
"topology.kubernetes.io/zone": "us-west4-a",
2147-
},
2246+
func TestGetPodZoneEnvVar(t *testing.T) {
2247+
tests := []struct {
2248+
name string
2249+
labels map[string]string
2250+
nodeSelector map[string]string
2251+
expectedVar string
2252+
}{
2253+
{
2254+
name: "Retrieve topology zone from labels",
2255+
labels: map[string]string{utils.K8sTopologyZoneLabel: "us-west4-a"},
2256+
expectedVar: "us-west4-a",
21482257
},
2149-
Spec: corev1.PodSpec{
2150-
Containers: []corev1.Container{
2151-
{Name: "ray"},
2152-
},
2153-
NodeSelector: map[string]string{
2154-
"eks.amazonaws.com/capacityType": "SPOT",
2155-
},
2258+
{
2259+
name: "Retrieve topology zone from nodeSelector",
2260+
nodeSelector: map[string]string{utils.K8sTopologyZoneLabel: "us-central2-b"},
2261+
expectedVar: "us-central2-b",
21562262
},
2263+
{
2264+
name: "Zone set using downward API",
2265+
expectedVar: "metadata.labels['topology.kubernetes.io/zone']",
2266+
},
2267+
}
2268+
for _, tt := range tests {
2269+
t.Run(tt.name, func(t *testing.T) {
2270+
pod := &corev1.Pod{
2271+
ObjectMeta: metav1.ObjectMeta{Labels: tt.labels},
2272+
Spec: corev1.PodSpec{NodeSelector: tt.nodeSelector},
2273+
}
2274+
// validate expected zone env var is parsed from Pod spec
2275+
result := getPodZoneEnvVar(pod)
2276+
if result.Value != "" {
2277+
if result.Value != tt.expectedVar {
2278+
t.Errorf("got env var %q, but expected %q", result.Value, tt.expectedVar)
2279+
}
2280+
} else if result.ValueFrom != nil {
2281+
if result.ValueFrom.FieldRef.FieldPath != tt.expectedVar {
2282+
t.Errorf("got FieldPath %q, but expected %q", result.ValueFrom.FieldRef.FieldPath, tt.expectedVar)
2283+
}
2284+
} else {
2285+
t.Errorf("getPodZoneEnvVar did not return expected env value")
2286+
}
2287+
})
21572288
}
2289+
}
21582290

2159-
addDefaultRayNodeLabels(pod)
2160-
rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
2161-
checkContainerEnv(t, rayContainer, utils.RayNodeMarketType, "spot")
2162-
checkContainerEnv(t, rayContainer, utils.RayNodeRegion, "metadata.labels['topology.kubernetes.io/region']")
2163-
checkContainerEnv(t, rayContainer, utils.RayNodeZone, "metadata.labels['topology.kubernetes.io/zone']")
2291+
func TestGetPodRegionEnvVar(t *testing.T) {
2292+
tests := []struct {
2293+
name string
2294+
labels map[string]string
2295+
nodeSelector map[string]string
2296+
expectedVar string
2297+
}{
2298+
{
2299+
name: "Retrieve topology region from labels",
2300+
labels: map[string]string{utils.K8sTopologyRegionLabel: "us-central1"},
2301+
expectedVar: "us-central1",
2302+
},
2303+
{
2304+
name: "Retrieve topology region from nodeSelector",
2305+
nodeSelector: map[string]string{utils.K8sTopologyRegionLabel: "us-central2"},
2306+
expectedVar: "us-central2",
2307+
},
2308+
{
2309+
name: "Region set using downward API",
2310+
expectedVar: "metadata.labels['topology.kubernetes.io/region']",
2311+
},
2312+
}
2313+
for _, tt := range tests {
2314+
t.Run(tt.name, func(t *testing.T) {
2315+
pod := &corev1.Pod{
2316+
ObjectMeta: metav1.ObjectMeta{Labels: tt.labels},
2317+
Spec: corev1.PodSpec{NodeSelector: tt.nodeSelector},
2318+
}
2319+
// validate expected region env var is parsed from Pod spec
2320+
result := getPodRegionEnvVar(pod)
2321+
if result.Value != "" {
2322+
if result.Value != tt.expectedVar {
2323+
t.Errorf("got env var %q, but expected %q", result.Value, tt.expectedVar)
2324+
}
2325+
} else if result.ValueFrom != nil {
2326+
if result.ValueFrom.FieldRef.FieldPath != tt.expectedVar {
2327+
t.Errorf("got FieldPath %q, but expected %q", result.ValueFrom.FieldRef.FieldPath, tt.expectedVar)
2328+
}
2329+
} else {
2330+
t.Errorf("getPodRegionEnvVar did not return expected env value")
2331+
}
2332+
})
2333+
}
21642334
}

0 commit comments

Comments
 (0)