diff --git a/ray-operator/apis/ray/v1/constant.go b/ray-operator/apis/ray/v1/constant.go
new file mode 100644
index 0000000000..66982cabac
--- /dev/null
+++ b/ray-operator/apis/ray/v1/constant.go
@@ -0,0 +1,12 @@
+package v1
+
+const (
+	// RayContainerIndex is the index of the Ray container: in KubeRay it must be the first application container in a head or worker Pod.
+	RayContainerIndex = 0
+
+	// RAY_REDIS_ADDRESS is the name of a container env variable.
+	RAY_REDIS_ADDRESS = "RAY_REDIS_ADDRESS"
+
+	// RayFTEnabledAnnotationKey is the Ray GCS FT related annotation key.
+	RayFTEnabledAnnotationKey = "ray.io/ft-enabled"
+)
diff --git a/ray-operator/apis/ray/v1/raycluster_webhook.go b/ray-operator/apis/ray/v1/raycluster_webhook.go
index 6650ef9534..3c4086e955 100644
--- a/ray-operator/apis/ray/v1/raycluster_webhook.go
+++ b/ray-operator/apis/ray/v1/raycluster_webhook.go
@@ -1,6 +1,7 @@
 package v1
 
 import (
+	"fmt"
 	"regexp"
 
 	apierrors "k8s.io/apimachinery/pkg/api/errors"
@@ -59,6 +60,10 @@ func (r *RayCluster) validateRayCluster() error {
 		allErrs = append(allErrs, err)
 	}
 
+	if err := r.ValidateRayClusterSpec(); err != nil {
+		allErrs = append(allErrs, err)
+	}
+
 	if len(allErrs) == 0 {
 		return nil
 	}
@@ -87,3 +92,26 @@ func (r *RayCluster) validateWorkerGroups() *field.Error {
 
 	return nil
 }
+
+// ValidateRayClusterSpec checks the GCS fault tolerance (FT) settings of the RayCluster.
+// It returns a field error when the ft-enabled annotation is "false" but spec.gcsFaultToleranceOptions is set,
+// or when the annotation is not "true" but the Ray container of the head Pod sets RAY_REDIS_ADDRESS; otherwise nil.
+func (r *RayCluster) ValidateRayClusterSpec() *field.Error {
+	if r.Annotations[RayFTEnabledAnnotationKey] == "false" && r.Spec.GcsFaultToleranceOptions != nil {
+		return field.Invalid(
+			field.NewPath("spec").Child("gcsFaultToleranceOptions"),
+			r.Spec.GcsFaultToleranceOptions,
+			fmt.Sprintf("GcsFaultToleranceOptions should be nil when %s annotation is set to false", RayFTEnabledAnnotationKey),
+		)
+	}
+	if r.Annotations[RayFTEnabledAnnotationKey] != "true" && len(r.Spec.HeadGroupSpec.Template.Spec.Containers) > 0 {
+		if EnvVarExists(RAY_REDIS_ADDRESS, r.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env) {
+			return field.Invalid(
+				field.NewPath("spec").Child("headGroupSpec").Child("template").Child("spec").Child("containers").Index(RayContainerIndex).Child("env"),
+				RAY_REDIS_ADDRESS,
+				fmt.Sprintf("%s should not be set when %s is disabled", RAY_REDIS_ADDRESS, RayFTEnabledAnnotationKey),
+			)
+		}
+	}
+	return nil
+}
diff --git a/ray-operator/apis/ray/v1/raycluster_webhook_test.go b/ray-operator/apis/ray/v1/raycluster_webhook_test.go
new file mode 100644
index 0000000000..a112e2cf1f
--- /dev/null
+++ b/ray-operator/apis/ray/v1/raycluster_webhook_test.go
@@ -0,0 +1,258 @@
+package v1
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	corev1 "k8s.io/api/core/v1"
+	apierrors "k8s.io/apimachinery/pkg/api/errors"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/util/validation/field"
+)
+
+// TestValidateRayClusterSpec covers combinations of the FT annotation, GcsFaultToleranceOptions and RAY_REDIS_ADDRESS.
+func TestValidateRayClusterSpec(t *testing.T) {
+	tests := []struct {
+		gcsFaultToleranceOptions *GcsFaultToleranceOptions
+		annotations              map[string]string
+		name                     string
+		errorMessage             string
+		envVars                  []corev1.EnvVar
+		expectError              bool
+	}{
+		{
+			name: "FT disabled with GcsFaultToleranceOptions set",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "false",
+			},
+			gcsFaultToleranceOptions: &GcsFaultToleranceOptions{},
+			expectError:              true,
+			errorMessage:             fmt.Sprintf("GcsFaultToleranceOptions should be nil when %s annotation is set to false", RayFTEnabledAnnotationKey),
+		},
+		{
+			name: "FT disabled with RAY_REDIS_ADDRESS set",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "false",
+			},
+			envVars: []corev1.EnvVar{
+				{
+					Name:  RAY_REDIS_ADDRESS,
+					Value: "redis://127.0.0.1:6379",
+				},
+			},
+			expectError:  true,
+			errorMessage: fmt.Sprintf("%s should not be set when %s is disabled", RAY_REDIS_ADDRESS, RayFTEnabledAnnotationKey),
+		},
+		{
+			name:        "FT not set with RAY_REDIS_ADDRESS set",
+			annotations: map[string]string{},
+			envVars: []corev1.EnvVar{
+				{
+					Name:  RAY_REDIS_ADDRESS,
+					Value: "redis://127.0.0.1:6379",
+				},
+			},
+			expectError:  true,
+			errorMessage: fmt.Sprintf("%s should not be set when %s is disabled", RAY_REDIS_ADDRESS, RayFTEnabledAnnotationKey),
+		},
+		{
+			name: "FT disabled with other environment variables set",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "false",
+			},
+			envVars: []corev1.EnvVar{
+				{
+					Name:  "SOME_OTHER_ENV",
+					Value: "some-value",
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "FT enabled, GcsFaultToleranceOptions not nil",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "true",
+			},
+			gcsFaultToleranceOptions: &GcsFaultToleranceOptions{
+				RedisAddress: "redis://127.0.0.1:6379",
+			},
+			expectError: false,
+		},
+		{
+			name: "FT enabled, GcsFaultToleranceOptions is nil",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "true",
+			},
+			expectError: false,
+		},
+		{
+			name: "FT enabled with other environment variables set",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "true",
+			},
+			envVars: []corev1.EnvVar{
+				{
+					Name:  "SOME_OTHER_ENV",
+					Value: "some-value",
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "FT enabled with RAY_REDIS_ADDRESS set",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "true",
+			},
+			envVars: []corev1.EnvVar{
+				{
+					Name:  RAY_REDIS_ADDRESS,
+					Value: "redis://127.0.0.1:6379",
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "FT disabled with no GcsFaultToleranceOptions and no RAY_REDIS_ADDRESS",
+			annotations: map[string]string{
+				RayFTEnabledAnnotationKey: "false",
+			},
+			expectError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			r := &RayCluster{
+				ObjectMeta: metav1.ObjectMeta{
+					Annotations: tt.annotations,
+				},
+				Spec: RayClusterSpec{
+					GcsFaultToleranceOptions: tt.gcsFaultToleranceOptions,
+					HeadGroupSpec: HeadGroupSpec{
+						Template: corev1.PodTemplateSpec{
+							Spec: corev1.PodSpec{
+								Containers: []corev1.Container{
+									{
+										Name: "ray-head",
+										Env:  tt.envVars,
+									},
+								},
+							},
+						},
+					},
+				},
+			}
+			err := r.ValidateRayClusterSpec()
+			if tt.expectError {
+				// Guard the field access: a failed NotNil does not stop the test, and err.Detail would panic on nil.
+				if assert.NotNil(t, err) {
+					assert.IsType(t, &field.Error{}, err)
+					assert.Equal(t, tt.errorMessage, err.Detail)
+				}
+			} else {
+				assert.Nil(t, err)
+			}
+		})
+	}
+}
+
+// TestValidateRayCluster checks that validateRayCluster aggregates name, worker group and FT spec errors.
+func TestValidateRayCluster(t *testing.T) {
+	tests := []struct {
+		GcsFaultToleranceOptions *GcsFaultToleranceOptions
+		name                     string
+		errorMessage             string
+		ObjectMeta               metav1.ObjectMeta
+		WorkerGroupSpecs         []WorkerGroupSpec
+		expectError              bool
+	}{
+		{
+			name: "Invalid name",
+			ObjectMeta: metav1.ObjectMeta{
+				Name: "Invalid_Name",
+			},
+			expectError:  true,
+			errorMessage: "name must consist of lower case alphanumeric characters or '-', start with an alphabetic character, and end with an alphanumeric character",
+		},
+		{
+			name: "Duplicate worker group names",
+
+			ObjectMeta: metav1.ObjectMeta{
+				Name: "valid-name",
+			},
+
+			WorkerGroupSpecs: []WorkerGroupSpec{
+				{GroupName: "group1"},
+				{GroupName: "group1"},
+			},
+
+			expectError:  true,
+			errorMessage: "worker group names must be unique",
+		},
+		{
+			name: "FT disabled with GcsFaultToleranceOptions set",
+
+			ObjectMeta: metav1.ObjectMeta{
+				Name: "valid-name",
+				Annotations: map[string]string{
+					RayFTEnabledAnnotationKey: "false",
+				},
+			},
+			GcsFaultToleranceOptions: &GcsFaultToleranceOptions{},
+			expectError:              true,
+			errorMessage:             fmt.Sprintf("GcsFaultToleranceOptions should be nil when %s annotation is set to false", RayFTEnabledAnnotationKey),
+		},
+		{
+			name: "Valid RayCluster",
+
+			ObjectMeta: metav1.ObjectMeta{
+				Name: "valid-name",
+				Annotations: map[string]string{
+					RayFTEnabledAnnotationKey: "true",
+				},
+			},
+			GcsFaultToleranceOptions: &GcsFaultToleranceOptions{
+				RedisAddress: "redis://127.0.0.1:6379",
+			},
+			WorkerGroupSpecs: []WorkerGroupSpec{
+				{GroupName: "group1"},
+				{GroupName: "group2"},
+			},
+			expectError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			rayCluster := &RayCluster{
+				ObjectMeta: tt.ObjectMeta,
+				Spec: RayClusterSpec{
+					GcsFaultToleranceOptions: tt.GcsFaultToleranceOptions,
+					HeadGroupSpec: HeadGroupSpec{
+						Template: corev1.PodTemplateSpec{
+							Spec: corev1.PodSpec{
+								Containers: []corev1.Container{
+									{
+										Name: "ray-head",
+									},
+								},
+							},
+						},
+					},
+					WorkerGroupSpecs: tt.WorkerGroupSpecs,
+				},
+			}
+			err := rayCluster.validateRayCluster()
+			if tt.expectError {
+				// Guard the method call: err.Error() would panic if validation unexpectedly returned nil.
+				if assert.NotNil(t, err) {
+					assert.IsType(t, &apierrors.StatusError{}, err)
+					assert.Contains(t, err.Error(), tt.errorMessage)
+				}
+			} else {
+				assert.Nil(t, err)
+			}
+		})
+	}
+}
diff --git a/ray-operator/apis/ray/v1/utils.go b/ray-operator/apis/ray/v1/utils.go
new file mode 100644
index 0000000000..59dff80124
--- /dev/null
+++ b/ray-operator/apis/ray/v1/utils.go
@@ -0,0 +1,13 @@
+package v1
+
+import corev1 "k8s.io/api/core/v1"
+
+// EnvVarExists reports whether envVars contains an entry whose Name equals envName.
+func EnvVarExists(envName string, envVars []corev1.EnvVar) bool {
+	for _, env := range envVars {
+		if env.Name == envName {
+			return true
+		}
+	}
+	return false
+}
diff --git a/ray-operator/controllers/ray/utils/constant.go b/ray-operator/controllers/ray/utils/constant.go
index b9c02c2d1d..582a58591a 100644
--- a/ray-operator/controllers/ray/utils/constant.go
+++ b/ray-operator/controllers/ray/utils/constant.go
@@ -1,6 +1,10 @@
 package utils
 
-import "errors"
+import (
+	"errors"
+
+	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
+)
 
 const (
 
@@ -28,7 +32,7 @@ const (
 	KubeRayVersion = "ray.io/kuberay-version"
 
 	// In KubeRay, the Ray container must be the first application container in a head or worker Pod.
-	RayContainerIndex = 0
+	RayContainerIndex = rayv1.RayContainerIndex
 
 	// Batch scheduling labels
 	// TODO(tgaddair): consider making these part of the CRD
@@ -37,7 +41,7 @@ const (
 	RayClusterGangSchedulingEnabled = "ray.io/gang-scheduling-enabled"
 
 	// Ray GCS FT related annotations
-	RayFTEnabledAnnotationKey         = "ray.io/ft-enabled"
+	RayFTEnabledAnnotationKey         = rayv1.RayFTEnabledAnnotationKey
 	RayExternalStorageNSAnnotationKey = "ray.io/external-storage-namespace"
 
 	// If this annotation is set to "true", the KubeRay operator will not modify the container's command.
@@ -98,7 +102,7 @@ const (
 	FQ_RAY_IP                           = "FQ_RAY_IP"
 	RAY_PORT                            = "RAY_PORT"
 	RAY_ADDRESS                         = "RAY_ADDRESS"
-	RAY_REDIS_ADDRESS                   = "RAY_REDIS_ADDRESS"
+	RAY_REDIS_ADDRESS                   = rayv1.RAY_REDIS_ADDRESS
 	REDIS_PASSWORD                      = "REDIS_PASSWORD"
 	RAY_DASHBOARD_ENABLE_K8S_DISK_USAGE = "RAY_DASHBOARD_ENABLE_K8S_DISK_USAGE"
 	RAY_EXTERNAL_STORAGE_NS             = "RAY_external_storage_namespace"
diff --git a/ray-operator/controllers/ray/utils/util.go b/ray-operator/controllers/ray/utils/util.go
index 7ea1ba8185..10179cf45d 100644
--- a/ray-operator/controllers/ray/utils/util.go
+++ b/ray-operator/controllers/ray/utils/util.go
@@ -589,14 +589,11 @@ func IsJobFinished(j *batchv1.Job) (batchv1.JobConditionType, bool) {
 	return "", false
 }
 
+// EnvVarExists reports whether envVars contains a variable named envName.
+// It stays a real function (not a reassignable package-level func variable) and delegates to rayv1.EnvVarExists.
 func EnvVarExists(envName string, envVars []corev1.EnvVar) bool {
-	for _, env := range envVars {
-		if env.Name == envName {
-			return true
-		}
-	}
-	return false
+	return rayv1.EnvVarExists(envName, envVars)
 }
 
 func UpsertEnvVar(envVars []corev1.EnvVar, newEnvVar corev1.EnvVar) []corev1.EnvVar {
 	overridden := false