diff --git a/ray-operator/apis/ray/v1/constant.go b/ray-operator/apis/ray/v1/constant.go new file mode 100644 index 00000000000..d8c25825d6b --- /dev/null +++ b/ray-operator/apis/ray/v1/constant.go @@ -0,0 +1,12 @@ +package v1 + +const ( + // In KubeRay, the Ray container must be the first application container in a head or worker Pod. + RayContainerIndex = 0 + + // Use as container env variable + RAY_REDIS_ADDRESS = "RAY_REDIS_ADDRESS" + REDIS_PASSWORD = "REDIS_PASSWORD" + // Ray GCS FT related annotations + RayFTEnabledAnnotationKey = "ray.io/ft-enabled" +) diff --git a/ray-operator/apis/ray/v1/pod.go b/ray-operator/apis/ray/v1/pod.go new file mode 100644 index 00000000000..c7ce4289c61 --- /dev/null +++ b/ray-operator/apis/ray/v1/pod.go @@ -0,0 +1,8 @@ +package v1 + +import "strings" + +func IsGCSFaultToleranceEnabled(instance RayCluster) bool { + v, ok := instance.Annotations[RayFTEnabledAnnotationKey] + return (ok && strings.ToLower(v) == "true") || instance.Spec.GcsFaultToleranceOptions != nil +} diff --git a/ray-operator/apis/ray/v1/pod_test.go b/ray-operator/apis/ray/v1/pod_test.go new file mode 100644 index 00000000000..68bf7f1786c --- /dev/null +++ b/ray-operator/apis/ray/v1/pod_test.go @@ -0,0 +1,70 @@ +package v1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestIsGCSFaultToleranceEnabled(t *testing.T) { + tests := []struct { + name string + instance RayCluster + expected bool + }{ + { + name: "ray.io/ft-enabled is true", + instance: RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + RayFTEnabledAnnotationKey: "true", + }, + }, + }, + expected: true, + }, + { + name: "ray.io/ft-enabled is false", + instance: RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + RayFTEnabledAnnotationKey: "false", + }, + }, + }, + expected: false, + }, + { + name: "ray.io/ft-enabled is nil, GcsFaultToleranceOptions is not nil", + instance: RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{}, + }, + Spec: RayClusterSpec{ + GcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, + }, + }, + expected: true, + }, + { + name: "ray.io/ft-enabled is nil, GcsFaultToleranceOptions is nil", + instance: RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{}, + }, + Spec: RayClusterSpec{ + GcsFaultToleranceOptions: nil, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsGCSFaultToleranceEnabled(tt.instance) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/ray-operator/apis/ray/v1/raycluster_webhook.go b/ray-operator/apis/ray/v1/raycluster_webhook.go index 6650ef9534f..dc77aa125e6 100644 --- a/ray-operator/apis/ray/v1/raycluster_webhook.go +++ b/ray-operator/apis/ray/v1/raycluster_webhook.go @@ -1,6 +1,7 @@ package v1 import ( + "fmt" "regexp" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -59,6 +60,10 @@ func (r *RayCluster) validateRayCluster() error { allErrs = append(allErrs, err) } + if err := r.ValidateRayClusterSpec(); err != nil { + allErrs = append(allErrs, err) + } + if len(allErrs) == 0 { return nil } @@ -87,3 +92,68 @@ func (r *RayCluster) validateWorkerGroups() *field.Error { return nil } + +func (r *RayCluster) ValidateRayClusterSpec() *field.Error { + if len(r.Spec.HeadGroupSpec.Template.Spec.Containers) == 0 { + return field.Invalid( + field.NewPath("spec").Child("headGroupSpec").Child("template").Child("spec").Child("containers"), + r.Spec.HeadGroupSpec.Template.Spec.Containers, + "headGroupSpec should have at least one container", + ) + } + + for i, workerGroup := range r.Spec.WorkerGroupSpecs { + if len(workerGroup.Template.Spec.Containers) == 0 { + return field.Invalid( + field.NewPath("spec").Child("workerGroupSpecs").Index(i), + workerGroup, + "workerGroupSpec should have at least one container", + ) + } + } + + if r.Annotations[RayFTEnabledAnnotationKey] != "" && r.Spec.GcsFaultToleranceOptions != nil { + return field.Invalid( + field.NewPath("metadata").Child("annotations").Child(RayFTEnabledAnnotationKey), + r.Annotations[RayFTEnabledAnnotationKey], + fmt.Sprintf("%s annotation and GcsFaultToleranceOptions are both set. "+ + "Please use only GcsFaultToleranceOptions to configure GCS fault tolerance", RayFTEnabledAnnotationKey), + ) + } + + if !IsGCSFaultToleranceEnabled(*r) { + if EnvVarExists(RAY_REDIS_ADDRESS, r.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env) { + return field.Invalid( + field.NewPath("spec").Child("headGroupSpec").Child("template").Child("spec").Child("containers").Index(RayContainerIndex).Child("env"), + r.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex].Env, + fmt.Sprintf("%s is set which implicitly enables GCS fault tolerance, "+ + "but GcsFaultToleranceOptions is not set. Please set GcsFaultToleranceOptions "+ + "to enable GCS fault tolerance", RAY_REDIS_ADDRESS), + ) + } + } + + if r.Spec.GcsFaultToleranceOptions != nil { + if redisPassword := r.Spec.HeadGroupSpec.RayStartParams["redis-password"]; redisPassword != "" { + return field.Invalid( + field.NewPath("spec").Child("headGroupSpec").Child("rayStartParams"), + r.Spec.HeadGroupSpec.RayStartParams, + "cannot set `redis-password` in rayStartParams when GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisPassword instead", + ) + } + + headContainer := r.Spec.HeadGroupSpec.Template.Spec.Containers[RayContainerIndex] + if EnvVarExists(REDIS_PASSWORD, headContainer.Env) { + return field.Invalid( + field.NewPath("spec").Child("headGroupSpec").Child("template").Child("spec").Child("containers").Index(RayContainerIndex).Child("env"), + headContainer.Env, + "cannot set `REDIS_PASSWORD` env var in head Pod when GcsFaultToleranceOptions is enabled - use GcsFaultToleranceOptions.RedisPassword instead", + ) + } + } + + // TODO (kevin85421): If GcsFaultToleranceOptions is set, users should use `GcsFaultToleranceOptions.RedisAddress` instead of `RAY_REDIS_ADDRESS`. + // TODO (kevin85421): If GcsFaultToleranceOptions is set, users should use `GcsFaultToleranceOptions.ExternalStorageNamespace` instead of + // the annotation `ray.io/external-storage-namespace`. + return nil +} diff --git a/ray-operator/apis/ray/v1/raycluster_webhook_test.go b/ray-operator/apis/ray/v1/raycluster_webhook_test.go new file mode 100644 index 00000000000..475afbca5e5 --- /dev/null +++ b/ray-operator/apis/ray/v1/raycluster_webhook_test.go @@ -0,0 +1,401 @@ +package v1 + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +func TestValidateRayClusterSpecGcsFaultToleranceOptions(t *testing.T) { + errorMessageBothSet := fmt.Sprintf("%s annotation and GcsFaultToleranceOptions are both set. "+ + "Please use only GcsFaultToleranceOptions to configure GCS fault tolerance", RayFTEnabledAnnotationKey) + errorMessageRedisAddressSet := fmt.Sprintf("%s is set which implicitly enables GCS fault tolerance, "+ + "but GcsFaultToleranceOptions is not set. Please set GcsFaultToleranceOptions "+ + "to enable GCS fault tolerance", RAY_REDIS_ADDRESS) + + tests := []struct { + gcsFaultToleranceOptions *GcsFaultToleranceOptions + annotations map[string]string + name string + errorMessage string + envVars []corev1.EnvVar + expectError bool + }{ + // GcsFaultToleranceOptions and ray.io/ft-enabled should not be both set. + { + name: "ray.io/ft-enabled is set to false and GcsFaultToleranceOptions is set", + annotations: map[string]string{ + RayFTEnabledAnnotationKey: "false", + }, + gcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, + expectError: true, + errorMessage: errorMessageBothSet, + }, + { + name: "ray.io/ft-enabled is set to true and GcsFaultToleranceOptions is set", + annotations: map[string]string{ + RayFTEnabledAnnotationKey: "true", + }, + gcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, + expectError: true, + errorMessage: errorMessageBothSet, + }, + { + name: "ray.io/ft-enabled is not set and GcsFaultToleranceOptions is set", + gcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, + expectError: false, + }, + { + name: "ray.io/ft-enabled is not set and GcsFaultToleranceOptions is not set", + gcsFaultToleranceOptions: nil, + expectError: false, + }, + // RAY_REDIS_ADDRESS should not be set if KubeRay is not aware that GCS fault tolerance is enabled. + { + name: "ray.io/ft-enabled is set to false and RAY_REDIS_ADDRESS is set", + annotations: map[string]string{ + RayFTEnabledAnnotationKey: "false", + }, + envVars: []corev1.EnvVar{ + { + Name: RAY_REDIS_ADDRESS, + Value: "redis:6379", + }, + }, + expectError: true, + errorMessage: errorMessageRedisAddressSet, + }, + { + name: "ray.io/ft-enabled is not set and RAY_REDIS_ADDRESS is set", + envVars: []corev1.EnvVar{ + { + Name: RAY_REDIS_ADDRESS, + Value: "redis:6379", + }, + }, + expectError: true, + errorMessage: errorMessageRedisAddressSet, + }, + { + name: "ray.io/ft-enabled is set to true and RAY_REDIS_ADDRESS is set", + annotations: map[string]string{ + RayFTEnabledAnnotationKey: "true", + }, + envVars: []corev1.EnvVar{ + { + Name: RAY_REDIS_ADDRESS, + Value: "redis:6379", + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rayCluster := &RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: tt.annotations, + }, + Spec: RayClusterSpec{ + GcsFaultToleranceOptions: tt.gcsFaultToleranceOptions, + HeadGroupSpec: HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Env: tt.envVars, + }, + }, + }, + }, + }, + }, + } + err := rayCluster.ValidateRayClusterSpec() + if tt.expectError { + assert.NotNil(t, err) + assert.IsType(t, &field.Error{}, err) + assert.Equal(t, err.Detail, tt.errorMessage) + } else { + assert.Nil(t, err) + } + }) + } +} + +func TestValidateRayClusterSpecRedisPassword(t *testing.T) { + tests := []struct { + gcsFaultToleranceOptions *GcsFaultToleranceOptions + name string + rayStartParams map[string]string + envVars []corev1.EnvVar + expectError bool + }{ + { + name: "GcsFaultToleranceOptions is set and `redis-password` is also set in rayStartParams", + gcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, + rayStartParams: map[string]string{ + "redis-password": "password", + }, + expectError: true, + }, + { + name: "GcsFaultToleranceOptions is set and `REDIS_PASSWORD` env var is also set in the head Pod", + gcsFaultToleranceOptions: &GcsFaultToleranceOptions{}, + envVars: []corev1.EnvVar{ + { + Name: REDIS_PASSWORD, + Value: "password", + }, + }, + expectError: true, + }, + { + name: "GcsFaultToleranceOptions.RedisPassword is set", + gcsFaultToleranceOptions: &GcsFaultToleranceOptions{ + RedisPassword: &RedisCredential{ + Value: "password", + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rayCluster := &RayCluster{ + Spec: RayClusterSpec{ + GcsFaultToleranceOptions: tt.gcsFaultToleranceOptions, + HeadGroupSpec: HeadGroupSpec{ + RayStartParams: tt.rayStartParams, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Env: tt.envVars, + }, + }, + }, + }, + }, + }, + } + err := rayCluster.ValidateRayClusterSpec() + if tt.expectError { + assert.NotNil(t, err) + assert.IsType(t, &field.Error{}, err) + } else { + assert.Nil(t, err) + } + }) + } +} + +func TestValidateRayClusterSpecEmptyContainers(t *testing.T) { + headGroupSpecWithOneContainer := HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "ray-head"}}, + }, + }, + } + workerGroupSpecWithOneContainer := WorkerGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "ray-worker"}}, + }, + }, + } + headGroupSpecWithNoContainers := *headGroupSpecWithOneContainer.DeepCopy() + headGroupSpecWithNoContainers.Template.Spec.Containers = []corev1.Container{} + workerGroupSpecWithNoContainers := *workerGroupSpecWithOneContainer.DeepCopy() + workerGroupSpecWithNoContainers.Template.Spec.Containers = []corev1.Container{} + + tests := []struct { + rayCluster *RayCluster + name string + errorMessage string + expectError bool + }{ + { + name: "headGroupSpec has no containers", + rayCluster: &RayCluster{ + Spec: RayClusterSpec{ + HeadGroupSpec: headGroupSpecWithNoContainers, + }, + }, + expectError: true, + errorMessage: "headGroupSpec should have at least one container", + }, + { + name: "workerGroupSpec has no containers", + rayCluster: &RayCluster{ + Spec: RayClusterSpec{ + HeadGroupSpec: headGroupSpecWithOneContainer, + WorkerGroupSpecs: []WorkerGroupSpec{workerGroupSpecWithNoContainers}, + }, + }, + expectError: true, + errorMessage: "workerGroupSpec should have at least one container", + }, + { + name: "valid cluster with containers in both head and worker groups", + rayCluster: &RayCluster{ + Spec: RayClusterSpec{ + HeadGroupSpec: headGroupSpecWithOneContainer, + WorkerGroupSpecs: []WorkerGroupSpec{workerGroupSpecWithOneContainer}, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.rayCluster.ValidateRayClusterSpec() + if tt.expectError { + assert.NotNil(t, err) + assert.IsType(t, &field.Error{}, err) + assert.Equal(t, err.Detail, tt.errorMessage) + } else { + assert.Nil(t, err) + } + }) + } +} + +func TestValidateRayCluster(t *testing.T) { + validHeadGroupSpec := HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + {Name: "ray-head"}, + }, + }, + }, + } + workerGroupSpec := WorkerGroupSpec{ + GroupName: "worker-group-1", + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + {Name: "ray-worker"}, + }, + }, + }, + } + workerGroupSpecs := []WorkerGroupSpec{workerGroupSpec} + + tests := []struct { + rayCluster *RayCluster + name string + errorMessage string + expectError bool + }{ + { + name: "valid RayCluster", + rayCluster: &RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "valid-name", + }, + Spec: RayClusterSpec{ + HeadGroupSpec: validHeadGroupSpec, + WorkerGroupSpecs: workerGroupSpecs, + }, + }, + expectError: false, + }, + { + name: "invalid rayCluster name", + rayCluster: &RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "Invalid_Name", + }, + Spec: RayClusterSpec{ + HeadGroupSpec: validHeadGroupSpec, + WorkerGroupSpecs: workerGroupSpecs, + }, + }, + expectError: true, + errorMessage: "name must consist of lower case alphanumeric characters or '-', start with an alphabetic character, and end with an alphanumeric character", + }, + { + name: "duplicate worker group names", + rayCluster: &RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "valid-name", + }, + Spec: RayClusterSpec{ + HeadGroupSpec: validHeadGroupSpec, + WorkerGroupSpecs: []WorkerGroupSpec{ + workerGroupSpec, + workerGroupSpec, + }, + }, + }, + expectError: true, + errorMessage: "worker group names must be unique", + }, + { + name: "headGroupSpec has no containers", + rayCluster: &RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "valid-name", + }, + Spec: RayClusterSpec{ + HeadGroupSpec: HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + }, + WorkerGroupSpecs: workerGroupSpecs, + }, + }, + expectError: true, + errorMessage: "headGroupSpec should have at least one container", + }, + { + name: "workerGroupSpec has no containers", + rayCluster: &RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "valid-name", + }, + Spec: RayClusterSpec{ + HeadGroupSpec: validHeadGroupSpec, + WorkerGroupSpecs: []WorkerGroupSpec{ + { + GroupName: "worker-group-1", + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + }, + }, + }, + }, + expectError: true, + errorMessage: "workerGroupSpec should have at least one container", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.rayCluster.validateRayCluster() + if tt.expectError { + assert.NotNil(t, err) + assert.IsType(t, &apierrors.StatusError{}, err) + assert.Contains(t, err.Error(), tt.errorMessage) + } else { + assert.Nil(t, err) + } + }) + } +} diff --git a/ray-operator/apis/ray/v1/utils.go b/ray-operator/apis/ray/v1/utils.go new file mode 100644 index 00000000000..59dff801249 --- /dev/null +++ b/ray-operator/apis/ray/v1/utils.go @@ -0,0 +1,12 @@ +package v1 + +import corev1 "k8s.io/api/core/v1" + +func EnvVarExists(envName string, envVars []corev1.EnvVar) bool { + for _, env := range envVars { + if env.Name == envName { + return true + } + } + return false +} diff --git a/ray-operator/apis/ray/v1/utils_test.go b/ray-operator/apis/ray/v1/utils_test.go new file mode 100644 index 00000000000..5358cfc9046 --- /dev/null +++ b/ray-operator/apis/ray/v1/utils_test.go @@ -0,0 +1,50 @@ +package v1 + +import ( + "testing" + + corev1 "k8s.io/api/core/v1" +) + +func TestEnvVarExists(t *testing.T) { + tests := []struct { + name string + envName string + envVars []corev1.EnvVar + expected bool + }{ + { + name: "env var exists", + envName: "EXISTING_ENV", + envVars: []corev1.EnvVar{ + {Name: "EXISTING_ENV", Value: "value1"}, + {Name: "ANOTHER_ENV", Value: "value2"}, + }, + expected: true, + }, + { + name: "env var does not exist", + envName: "NON_EXISTING_ENV", + envVars: []corev1.EnvVar{ + {Name: "EXISTING_ENV", Value: "value1"}, + {Name: "ANOTHER_ENV", Value: "value2"}, + }, + expected: false, + }, + { + name: "empty env vars", + envName: "ANY_ENV", + envVars: []corev1.EnvVar{}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := EnvVarExists(tt.envName, tt.envVars) + if result != tt.expected { + t.Errorf("EnvVarExists(%s, %v) = %v; expected %v", tt.envName, tt.envVars, result, tt.expected) + } + }) + } +} diff --git a/ray-operator/apis/ray/v1/webhook_suite_test.go b/ray-operator/apis/ray/v1/webhook_suite_test.go index 52f4e10f402..98f45546135 100644 --- a/ray-operator/apis/ray/v1/webhook_suite_test.go +++ b/ray-operator/apis/ray/v1/webhook_suite_test.go @@ -136,11 +136,25 @@ var _ = Describe("RayCluster validating webhook", func() { RayStartParams: map[string]string{"DEADBEEF": "DEADBEEF"}, Template: corev1.PodTemplateSpec{ Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, + Containers: []corev1.Container{ + {Name: "ray-head"}, + }, + }, + }, + }, + WorkerGroupSpecs: []WorkerGroupSpec{ + { + GroupName: "worker-group-1", + RayStartParams: map[string]string{"DEADBEEF": "DEADBEEF"}, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + {Name: "ray-worker"}, + }, + }, }, }, }, - WorkerGroupSpecs: []WorkerGroupSpec{}, }, } diff --git a/ray-operator/controllers/ray/common/pod.go b/ray-operator/controllers/ray/common/pod.go index facf340211b..94b4d3baf8e 100644 --- a/ray-operator/controllers/ray/common/pod.go +++ b/ray-operator/controllers/ray/common/pod.go @@ -56,10 +56,7 @@ func GetHeadPort(headStartParams map[string]string) string { } // Check if the RayCluster has GCS fault tolerance enabled. -func IsGCSFaultToleranceEnabled(instance rayv1.RayCluster) bool { - v, ok := instance.Annotations[utils.RayFTEnabledAnnotationKey] - return (ok && strings.ToLower(v) == "true") || instance.Spec.GcsFaultToleranceOptions != nil -} +var IsGCSFaultToleranceEnabled func(instance rayv1.RayCluster) bool = rayv1.IsGCSFaultToleranceEnabled // Check if overwrites the container command. func isOverwriteRayContainerCmd(instance rayv1.RayCluster) bool { diff --git a/ray-operator/controllers/ray/raycluster_controller_unit_test.go b/ray-operator/controllers/ray/raycluster_controller_unit_test.go index a9d6a5c316a..5575f587936 100644 --- a/ray-operator/controllers/ray/raycluster_controller_unit_test.go +++ b/ray-operator/controllers/ray/raycluster_controller_unit_test.go @@ -3533,7 +3533,7 @@ func TestValidateRayClusterSpecGcsFaultToleranceOptions(t *testing.T) { errorMessage: errorMessageRedisAddressSet, }, { - name: "FT is disabled and RAY_REDIS_ADDRESS is set", + name: "ray.io/ft-enabled is not set and RAY_REDIS_ADDRESS is set", envVars: []corev1.EnvVar{ { Name: utils.RAY_REDIS_ADDRESS, diff --git a/ray-operator/controllers/ray/utils/constant.go b/ray-operator/controllers/ray/utils/constant.go index e9c97a6e36c..3fe775fce00 100644 --- a/ray-operator/controllers/ray/utils/constant.go +++ b/ray-operator/controllers/ray/utils/constant.go @@ -1,6 +1,10 @@ package utils -import "errors" +import ( + "errors" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" +) const ( @@ -28,7 +32,7 @@ const ( KubeRayVersion = "ray.io/kuberay-version" // In KubeRay, the Ray container must be the first application container in a head or worker Pod. - RayContainerIndex = 0 + RayContainerIndex = rayv1.RayContainerIndex // Batch scheduling labels // TODO(tgaddair): consider making these part of the CRD @@ -37,7 +41,7 @@ const ( RayClusterGangSchedulingEnabled = "ray.io/gang-scheduling-enabled" // Ray GCS FT related annotations - RayFTEnabledAnnotationKey = "ray.io/ft-enabled" + RayFTEnabledAnnotationKey = rayv1.RayFTEnabledAnnotationKey RayExternalStorageNSAnnotationKey = "ray.io/external-storage-namespace" // If this annotation is set to "true", the KubeRay operator will not modify the container's command. @@ -95,8 +99,8 @@ const ( FQ_RAY_IP = "FQ_RAY_IP" RAY_PORT = "RAY_PORT" RAY_ADDRESS = "RAY_ADDRESS" - RAY_REDIS_ADDRESS = "RAY_REDIS_ADDRESS" - REDIS_PASSWORD = "REDIS_PASSWORD" + RAY_REDIS_ADDRESS = rayv1.RAY_REDIS_ADDRESS + REDIS_PASSWORD = rayv1.REDIS_PASSWORD RAY_DASHBOARD_ENABLE_K8S_DISK_USAGE = "RAY_DASHBOARD_ENABLE_K8S_DISK_USAGE" RAY_EXTERNAL_STORAGE_NS = "RAY_external_storage_namespace" RAY_GCS_RPC_SERVER_RECONNECT_TIMEOUT_S = "RAY_gcs_rpc_server_reconnect_timeout_s" diff --git a/ray-operator/controllers/ray/utils/util.go b/ray-operator/controllers/ray/utils/util.go index 7ea1ba8185e..10179cf45d6 100644 --- a/ray-operator/controllers/ray/utils/util.go +++ b/ray-operator/controllers/ray/utils/util.go @@ -589,14 +589,7 @@ func IsJobFinished(j *batchv1.Job) (batchv1.JobConditionType, bool) { return "", false } -func EnvVarExists(envName string, envVars []corev1.EnvVar) bool { - for _, env := range envVars { - if env.Name == envName { - return true - } - } - return false -} +var EnvVarExists func(envName string, envVars []corev1.EnvVar) bool = rayv1.EnvVarExists func UpsertEnvVar(envVars []corev1.EnvVar, newEnvVar corev1.EnvVar) []corev1.EnvVar { overridden := false