diff --git a/ray-operator/apis/ray/v1/constant.go b/ray-operator/apis/ray/v1/constant.go new file mode 100644 index 00000000000..fedf81a91f3 --- /dev/null +++ b/ray-operator/apis/ray/v1/constant.go @@ -0,0 +1,10 @@ +package v1 + +// In KubeRay, the Ray container must be the first application container in a head or worker Pod. +const RayContainerIndex = 0 + +// Use as container env variable +const RAY_REDIS_ADDRESS = "RAY_REDIS_ADDRESS" + +// Ray GCS FT related annotations +const RayFTEnabledAnnotationKey = "ray.io/ft-enabled" diff --git a/ray-operator/apis/ray/v1/raycluster_webhook.go b/ray-operator/apis/ray/v1/raycluster_webhook.go index 1f9ac2f9022..37c663136a0 100644 --- a/ray-operator/apis/ray/v1/raycluster_webhook.go +++ b/ray-operator/apis/ray/v1/raycluster_webhook.go @@ -1,6 +1,7 @@ package v1 import ( + "fmt" "regexp" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -19,8 +20,6 @@ var ( nameRegex, _ = regexp.Compile("^[a-z]([-a-z0-9]*[a-z0-9])?$") ) -const RayFTEnabledAnnotationKey = "ray.io/ft-enabled" - func (r *RayCluster) SetupWebhookWithManager(mgr ctrl.Manager) error { return ctrl.NewWebhookManagedBy(mgr). For(r). @@ -96,13 +95,22 @@ func (r *RayCluster) validateWorkerGroups() *field.Error { func (r *RayCluster) ValidateRayClusterSpec() *field.Error { if r.Annotations[RayFTEnabledAnnotationKey] == "false" && r.Spec.GcsFaultToleranceOptions != nil { - return field.Invalid(field.NewPath("spec").Child("gcsFaultToleranceOptions"), r.Spec.GcsFaultToleranceOptions, "GcsFaultToleranceOptions should be nil when ray.io/ft-enabled is disabled") + return field.Invalid( + field.NewPath("spec").Child("gcsFaultToleranceOptions"), + r.Spec.GcsFaultToleranceOptions, + fmt.Sprintf("GcsFaultToleranceOptions should be nil when %s is disabled", RayFTEnabledAnnotationKey), + ) } - if r.Annotations[RayFTEnabledAnnotationKey] != "true" && r.Spec.HeadGroupSpec.Template.Spec.Containers[0].Env != nil { - for _, env := range r.Spec.HeadGroupSpec.Template.Spec.Containers[0].Env { - if env.Name == "RAY_REDIS_ADDRESS" { - return field.Invalid(field.NewPath("spec").Child("headGroupSpec").Child("template").Child("spec").Child("containers").Index(0).Child("env"), env.Name, "RAY_REDIS_ADDRESS should not be set when ray.io/ft-enabled is disabled") - } + if r.Annotations[RayFTEnabledAnnotationKey] != "true" && + len(r.Spec.HeadGroupSpec.Template.Spec.Containers) > 0 && + r.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env != nil { + + if EnvVarExists(RAY_REDIS_ADDRESS, r.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env) { + return field.Invalid( + field.NewPath("spec").Child("headGroupSpec").Child("template").Child("spec").Child("containers").Index(0).Child("env"), + RAY_REDIS_ADDRESS, + fmt.Sprintf("%s should not be set when %s is disabled", RAY_REDIS_ADDRESS, RayFTEnabledAnnotationKey), + ) } } return nil diff --git a/ray-operator/apis/ray/v1/raycluster_webhook_test.go b/ray-operator/apis/ray/v1/raycluster_webhook_test.go index c30f62dba9f..f6150bc9166 100644 --- a/ray-operator/apis/ray/v1/raycluster_webhook_test.go +++ b/ray-operator/apis/ray/v1/raycluster_webhook_test.go @@ -1,6 +1,7 @@ package v1 import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -26,7 +27,7 @@ func TestValidateRayClusterSpec(t *testing.T) { }, gcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, expectError: true, - errorMessage: "GcsFaultToleranceOptions should be nil when ray.io/ft-enabled is disabled", + errorMessage: fmt.Sprintf("GcsFaultToleranceOptions should be nil when %s is disabled", RayFTEnabledAnnotationKey), }, { name: "FT disabled with RAY_REDIS_ADDRESS set", @@ -35,24 +36,24 @@ func TestValidateRayClusterSpec(t *testing.T) { }, envVars: []corev1.EnvVar{ { - Name: "RAY_REDIS_ADDRESS", + Name: RAY_REDIS_ADDRESS, Value: "redis://127.0.0.1:6379", }, }, expectError: true, - errorMessage: "RAY_REDIS_ADDRESS should not be set when ray.io/ft-enabled is disabled", + errorMessage: fmt.Sprintf("%s should not be set when %s is disabled", RAY_REDIS_ADDRESS, RayFTEnabledAnnotationKey), }, { name: "FT not set with RAY_REDIS_ADDRESS set", annotations: map[string]string{}, envVars: []corev1.EnvVar{ { - Name: "RAY_REDIS_ADDRESS", + Name: RAY_REDIS_ADDRESS, Value: "redis://127.0.0.1:6379", }, }, expectError: true, - errorMessage: "RAY_REDIS_ADDRESS should not be set when ray.io/ft-enabled is disabled", + errorMessage: fmt.Sprintf("%s should not be set when %s is disabled", RAY_REDIS_ADDRESS, RayFTEnabledAnnotationKey), }, { name: "FT disabled with other environment variables set", @@ -104,7 +105,7 @@ func TestValidateRayClusterSpec(t *testing.T) { }, envVars: []corev1.EnvVar{ { - Name: "RAY_REDIS_ADDRESS", + Name: RAY_REDIS_ADDRESS, Value: "redis://127.0.0.1:6379", }, }, @@ -196,7 +197,7 @@ func TestValidateRayCluster(t *testing.T) { }, GcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, expectError: true, - errorMessage: "GcsFaultToleranceOptions should be nil when ray.io/ft-enabled is disabled", + errorMessage: fmt.Sprintf("GcsFaultToleranceOptions should be nil when %s is disabled", RayFTEnabledAnnotationKey), }, { name: "Valid RayCluster", diff --git a/ray-operator/apis/ray/v1/utils.go b/ray-operator/apis/ray/v1/utils.go new file mode 100644 index 00000000000..59dff801249 --- /dev/null +++ b/ray-operator/apis/ray/v1/utils.go @@ -0,0 +1,12 @@ +package v1 + +import corev1 "k8s.io/api/core/v1" + +func EnvVarExists(envName string, envVars []corev1.EnvVar) bool { + for _, env := range envVars { + if env.Name == envName { + return true + } + } + return false +}