From 8ca03df2828aae340a28759d75a1f3e5bba78f2c Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Sun, 3 Nov 2024 15:14:31 +0900 Subject: [PATCH] KEP-2170: Implement TrainJob conditions Signed-off-by: Yuki Iwai --- hack/violation_exception_v2alpha1.list | 1 - .../v2/base/crds/kubeflow.org_trainjobs.yaml | 3 + .../v2alpha1/openapi_generated.go | 10 ++ .../kubeflow.org/v2alpha1/trainjob_types.go | 45 ++++- pkg/controller.v2/trainjob_controller.go | 124 ++++++++++++-- pkg/runtime.v2/core/clustertrainingruntime.go | 5 + pkg/runtime.v2/core/trainingruntime.go | 5 + pkg/runtime.v2/framework/core/framework.go | 19 +++ .../framework/core/framework_test.go | 115 ++++++++++++- pkg/runtime.v2/framework/interface.go | 6 + .../plugins/coscheduling/coscheduling.go | 2 +- .../framework/plugins/jobset/jobset.go | 19 +++ pkg/runtime.v2/interface.go | 2 + pkg/util.v2/testing/wrapper.go | 7 + .../controller.v2/trainjob_controller_test.go | 156 +++++++++++++++++- test/integration/framework/framework.go | 4 +- test/util/constants.go | 3 + 17 files changed, 503 insertions(+), 23 deletions(-) diff --git a/hack/violation_exception_v2alpha1.list b/hack/violation_exception_v2alpha1.list index b636df625d..c18c46f8b9 100644 --- a/hack/violation_exception_v2alpha1.list +++ b/hack/violation_exception_v2alpha1.list @@ -13,7 +13,6 @@ API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/ API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,PodSpecOverride,Volumes API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TorchElasticPolicy,Metrics API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobSpec,PodSpecOverrides -API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,Conditions API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,JobsStatus API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Args API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Command diff --git a/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml b/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml index ed6cda3760..a0ae3ef0ff 100644 --- a/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml +++ b/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml @@ -3055,6 +3055,9 @@ spec: - type type: object type: array + x-kubernetes-list-map-keys: + - type + x-kubernetes-list-type: map jobsStatus: description: JobsStatus tracks the child Jobs in TrainJob. 
items: diff --git a/pkg/apis/kubeflow.org/v2alpha1/openapi_generated.go b/pkg/apis/kubeflow.org/v2alpha1/openapi_generated.go index 5394285cda..d4079da74f 100644 --- a/pkg/apis/kubeflow.org/v2alpha1/openapi_generated.go +++ b/pkg/apis/kubeflow.org/v2alpha1/openapi_generated.go @@ -1110,6 +1110,16 @@ func schema_pkg_apis_kubefloworg_v2alpha1_TrainJobStatus(ref common.ReferenceCal Type: []string{"object"}, Properties: map[string]spec.Schema{ "conditions": { + VendorExtensible: spec.VendorExtensible{ + Extensions: spec.Extensions{ + "x-kubernetes-list-map-keys": []interface{}{ + "type", + }, + "x-kubernetes-list-type": "map", + "x-kubernetes-patch-merge-key": "type", + "x-kubernetes-patch-strategy": "merge", + }, + }, SchemaProps: spec.SchemaProps{ Description: "Conditions for the TrainJob.", Type: []string{"array"}, diff --git a/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go b/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go index 55a813350e..80281b926f 100644 --- a/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go +++ b/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go @@ -48,6 +48,43 @@ type TrainJob struct { Status TrainJobStatus `json:"status,omitempty"` } +const ( + // TrainJobSuspended means the TrainJob is suspended. + TrainJobSuspended string = "Suspended" + + // TrainJobComplete means that the TrainJob has completed its execution. + TrainJobComplete string = "Complete" + + // TrainJobFailed means that the actual jobs have failed its execution. + TrainJobFailed string = "Failed" + + // TrainJobCreated means that the actual jobs creation has succeeded. + TrainJobCreated string = "Created" +) + +const ( + // TrainJobSuspendedReason is the "Suspended" condition reason. + // When the TrainJob is suspended, this is added. + TrainJobSuspendedReason string = "Suspended" + + // TrainJobResumedReason is the "Suspended" condition reason. + // When the TrainJob suspension is changed from True to False, this is added. + TrainJobResumedReason string = "Resumed" + + // TrainJobJobsCreationSucceededReason is the "Created" condition reason. + // When the creating objects succeeded after building succeeded, this is added. + TrainJobJobsCreationSucceededReason string = "JobsCreationSucceeded" + + // TrainJobJobsBuildFailedReason is the "Created" condition reason. + // When the building objects based on the TrainJob and the specified runtime failed, + // this is added. + TrainJobJobsBuildFailedReason string = "JobsBuildFailed" + + // TrainJobJobsCreationFailedReason is the "Created" condition reason. + // When the creating objects failed even though building succeeded, this is added. + TrainJobJobsCreationFailedReason string = "JobsCreationFailed" +) + // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object // +resource:path=trainjobs @@ -269,7 +306,13 @@ type ContainerOverride struct { // TrainJobStatus represents the current status of TrainJob. type TrainJobStatus struct { // Conditions for the TrainJob. - Conditions []metav1.Condition `json:"conditions,omitempty"` + // + // +optional + // +listType=map + // +listMapKey=type + // +patchStrategy=merge + // +patchMergeKey=type + Conditions []metav1.Condition `json:"conditions,omitempty" patchStrategy:"merge" patchMergeKey:"type"` // JobsStatus tracks the child Jobs in TrainJob. 
JobsStatus []JobStatus `json:"jobsStatus,omitempty"` diff --git a/pkg/controller.v2/trainjob_controller.go b/pkg/controller.v2/trainjob_controller.go index 95a34048e0..2fad2885be 100644 --- a/pkg/controller.v2/trainjob_controller.go +++ b/pkg/controller.v2/trainjob_controller.go @@ -22,6 +22,9 @@ import ( "fmt" "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/api/equality" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" @@ -36,6 +39,15 @@ import ( var errorUnsupportedRuntime = errors.New("the specified runtime is not supported") +type objsOpState int + +const ( + succeeded objsOpState = iota + buildFailed objsOpState = iota + creationFailed objsOpState = iota + updateFailed objsOpState = iota +) + type TrainJobReconciler struct { log logr.Logger client client.Client @@ -63,29 +75,41 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob)) ctx = ctrl.LoggerInto(ctx, log) log.V(2).Info("Reconciling TrainJob") - if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil { - return ctrl.Result{}, err + if isTrainJobFinished(&trainJob) { + log.V(5).Info("TrainJob has already been finished") + return ctrl.Result{}, nil } - // TODO (tenzen-y): Do update the status. - return ctrl.Result{}, nil -} - -func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error { - log := ctrl.LoggerFrom(ctx) runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String() runtime, ok := r.runtimes[runtimeRefGK] if !ok { - return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK) + return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK) + } + opState, err := r.reconcileObjects(ctx, runtime, &trainJob) + + originStatus := trainJob.Status.DeepCopy() + setSuspendedCondition(&trainJob) + setCreatedCondition(&trainJob, opState) + if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil { + return ctrl.Result{}, errors.Join(err, terminalCondErr) + } + if !equality.Semantic.DeepEqual(&trainJob, originStatus) { + return ctrl.Result{}, errors.Join(err, r.client.Status().Update(ctx, &trainJob)) } + return ctrl.Result{}, err +} + +func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) (objsOpState, error) { + log := ctrl.LoggerFrom(ctx) + objs, err := runtime.NewObjects(ctx, trainJob) if err != nil { - return err + return buildFailed, err } for _, obj := range objs { var gvk schema.GroupVersionKind if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil { - return err + return buildFailed, err } logKeysAndValues := []any{ "groupVersionKind", gvk.String(), @@ -102,21 +126,91 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k } switch { case created: - log.V(5).Info("Succeeded to create object", logKeysAndValues) + log.V(5).Info("Succeeded to create object", logKeysAndValues...) continue case client.IgnoreAlreadyExists(creationErr) != nil: - return creationErr + return creationFailed, creationErr default: // This indicates CREATE operation has not been performed or the object has already existed in the cluster. 
if err = r.client.Update(ctx, obj); err != nil { - return err + return updateFailed, err } - log.V(5).Info("Succeeded to update object", logKeysAndValues) + log.V(5).Info("Succeeded to update object", logKeysAndValues...) } } + return succeeded, nil +} + +func setCreatedCondition(trainJob *kubeflowv2.TrainJob, opState objsOpState) { + var newCond metav1.Condition + switch opState { + case succeeded: + newCond = metav1.Condition{ + Type: kubeflowv2.TrainJobCreated, + Status: metav1.ConditionTrue, + Message: "Succeeded to create Jobs", + Reason: kubeflowv2.TrainJobJobsCreationSucceededReason, + } + case buildFailed: + newCond = metav1.Condition{ + Type: kubeflowv2.TrainJobCreated, + Status: metav1.ConditionFalse, + Message: "Failed to build Jobs", + Reason: kubeflowv2.TrainJobJobsBuildFailedReason, + } + // TODO (tenzen-y): Provide more granular the message based on creation or update failure. + case creationFailed, updateFailed: + newCond = metav1.Condition{ + Type: kubeflowv2.TrainJobCreated, + Status: metav1.ConditionFalse, + Message: "Failed to create Jobs", + Reason: kubeflowv2.TrainJobJobsCreationFailedReason, + } + default: + return + } + meta.SetStatusCondition(&trainJob.Status.Conditions, newCond) +} + +func setSuspendedCondition(trainJob *kubeflowv2.TrainJob) { + var newCond metav1.Condition + switch { + case ptr.Deref(trainJob.Spec.Suspend, false): + newCond = metav1.Condition{ + Type: kubeflowv2.TrainJobSuspended, + Status: metav1.ConditionTrue, + Message: "TrainJob is suspended", + Reason: kubeflowv2.TrainJobSuspendedReason, + } + case meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobSuspended): + newCond = metav1.Condition{ + Type: kubeflowv2.TrainJobSuspended, + Status: metav1.ConditionFalse, + Message: "TrainJob is resumed", + Reason: kubeflowv2.TrainJobResumedReason, + } + default: + return + } + meta.SetStatusCondition(&trainJob.Status.Conditions, newCond) +} + +func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) error { + terminalCond, err := runtime.TerminalCondition(ctx, trainJob) + if err != nil { + return err + } + if terminalCond != nil { + meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond) + } return nil } +func isTrainJobFinished(trainJob *kubeflowv2.TrainJob) bool { + return meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobComplete) || + meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobFailed) +} + func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind { return schema.GroupKind{ Group: ptr.Deref(runtimeRef.APIGroup, ""), diff --git a/pkg/runtime.v2/core/clustertrainingruntime.go b/pkg/runtime.v2/core/clustertrainingruntime.go index 35c35fe0c9..6f2cb8bac3 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime.go +++ b/pkg/runtime.v2/core/clustertrainingruntime.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -59,6 +60,10 @@ func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubef return r.buildObjects(ctx, trainJob, clTrainingRuntime.Spec.Template, clTrainingRuntime.Spec.MLPolicy, clTrainingRuntime.Spec.PodGroupPolicy) } +func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) { + return 
r.TrainingRuntime.TerminalCondition(ctx, trainJob) +} + func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { return nil } diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 5a6ab569bd..44a3a420d1 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" @@ -127,6 +128,10 @@ func (r *TrainingRuntime) buildObjects( return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob) } +func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) { + return r.framework.RunTerminalConditionPlugins(ctx, trainJob) +} + func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { var builders []runtime.ReconcilerBuilder for _, ex := range r.framework.WatchExtensionPlugins() { diff --git a/pkg/runtime.v2/framework/core/framework.go b/pkg/runtime.v2/framework/core/framework.go index d6955335bb..e2ecfdca03 100644 --- a/pkg/runtime.v2/framework/core/framework.go +++ b/pkg/runtime.v2/framework/core/framework.go @@ -18,7 +18,9 @@ package core import ( "context" + "errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -29,6 +31,8 @@ import ( fwkplugins "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins" ) +var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered") + type Framework struct { registry fwkplugins.Registry plugins map[string]framework.Plugin @@ -37,6 +41,7 @@ type Framework struct { customValidationPlugins []framework.CustomValidationPlugin watchExtensionPlugins []framework.WatchExtensionPlugin componentBuilderPlugins []framework.ComponentBuilderPlugin + terminalConditionPlugins []framework.TerminalConditionPlugin } func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) { @@ -66,6 +71,9 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl if p, ok := plugin.(framework.ComponentBuilderPlugin); ok { f.componentBuilderPlugins = append(f.componentBuilderPlugins, p) } + if p, ok := plugin.(framework.TerminalConditionPlugin); ok { + f.terminalConditionPlugins = append(f.terminalConditionPlugins, p) + } } f.plugins = plugins return f, nil @@ -118,6 +126,17 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTe return objs, nil } +func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) { + // TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points. 
+ if len(f.terminalConditionPlugins) > 1 { + return nil, errorTooManyTerminalConditionPlugin + } + if len(f.terminalConditionPlugins) != 0 { + return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob) + } + return nil, nil +} + func (f *Framework) WatchExtensionPlugins() []framework.WatchExtensionPlugin { return f.watchExtensionPlugins } diff --git a/pkg/runtime.v2/framework/core/framework_test.go b/pkg/runtime.v2/framework/core/framework_test.go index 69255c4016..fc05c72ff8 100644 --- a/pkg/runtime.v2/framework/core/framework_test.go +++ b/pkg/runtime.v2/framework/core/framework_test.go @@ -18,6 +18,7 @@ package core import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -31,6 +32,8 @@ import ( "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + jobsetconsts "sigs.k8s.io/jobset/pkg/constants" schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" @@ -89,6 +92,9 @@ func TestNew(t *testing.T) { &coscheduling.CoScheduling{}, &jobset.JobSet{}, }, + terminalConditionPlugins: []framework.TerminalConditionPlugin{ + &jobset.JobSet{}, + }, }, }, "indexer key for trainingRuntime and runtimeClass is an empty": { @@ -479,7 +485,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { Obj(), }, }, - // "an empty registry": {}, + "an empty registry": {}, } cmpOpts := []cmp.Option{ cmpopts.SortSlices(func(a, b client.Object) bool { @@ -518,7 +524,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { } } -func TestRunExtensionPlugins(t *testing.T) { +func TestWatchRunExtensionPlugins(t *testing.T) { cases := map[string]struct { registry fwkplugins.Registry wantPlugins []framework.WatchExtensionPlugin @@ -555,3 +561,108 @@ func TestRunExtensionPlugins(t *testing.T) { }) } } + +type fakeTerminalConditionPlugin struct{} + +var _ framework.TerminalConditionPlugin = (*fakeTerminalConditionPlugin)(nil) + +func newFakeTerminalConditionPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { + return &fakeTerminalConditionPlugin{}, nil +} + +const fakeTerminalConditionPluginName = "fake" + +func (f fakeTerminalConditionPlugin) Name() string { return fakeTerminalConditionPluginName } +func (f fakeTerminalConditionPlugin) TerminalCondition(context.Context, *kubeflowv2.TrainJob) (*metav1.Condition, error) { + return nil, nil +} + +func TestTerminalConditionPlugins(t *testing.T) { + cases := map[string]struct { + registry fwkplugins.Registry + trainJob *kubeflowv2.TrainJob + jobSet *jobsetv1alpha2.JobSet + wantCondition *metav1.Condition + wantError error + }{ + "jobSet has not been finalized, yet": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + Conditions(metav1.Condition{ + Type: string(jobsetv1alpha2.JobSetSuspended), + Reason: jobsetconsts.JobSetSuspendedReason, + Message: jobsetconsts.JobSetSuspendedMessage, + Status: metav1.ConditionFalse, + }). + Obj(), + }, + "succeeded to obtain completed terminal condition": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). 
+ Conditions(metav1.Condition{ + Type: string(jobsetv1alpha2.JobSetCompleted), + Reason: jobsetconsts.AllJobsCompletedReason, + Message: jobsetconsts.AllJobsCompletedMessage, + Status: metav1.ConditionTrue, + }). + Obj(), + wantCondition: &metav1.Condition{ + Type: kubeflowv2.TrainJobComplete, + Reason: fmt.Sprintf("%sDueTo%s", jobsetv1alpha2.JobSetCompleted, jobsetconsts.AllJobsCompletedReason), + Message: jobsetconsts.AllJobsCompletedMessage, + Status: metav1.ConditionTrue, + }, + }, + "succeeded to obtain failed terminal condition": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + Conditions(metav1.Condition{ + Type: string(jobsetv1alpha2.JobSetFailed), + Reason: jobsetconsts.FailedJobsReason, + Message: jobsetconsts.FailedJobsMessage, + Status: metav1.ConditionTrue, + }). + Obj(), + wantCondition: &metav1.Condition{ + Type: kubeflowv2.TrainJobFailed, + Reason: fmt.Sprintf("%sDueTo%s", jobsetv1alpha2.JobSetFailed, jobsetconsts.FailedJobsReason), + Message: jobsetconsts.FailedJobsMessage, + Status: metav1.ConditionTrue, + }, + }, + "failed to obtain any terminal condition due to multiple terminalCondition plugin": { + registry: fwkplugins.Registry{ + jobset.Name: jobset.New, + fakeTerminalConditionPluginName: newFakeTerminalConditionPlugin, + }, + wantError: errorTooManyTerminalConditionPlugin, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + if tc.jobSet != nil { + clientBuilder = clientBuilder.WithObjects(tc.jobSet) + } + fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + if err != nil { + t.Fatal(err) + } + gotCond, gotErr := fwk.RunTerminalConditionPlugins(ctx, tc.trainJob) + if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantCondition, gotCond); len(diff) != 0 { + t.Errorf("Unexpected terminal condition (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/runtime.v2/framework/interface.go b/pkg/runtime.v2/framework/interface.go index a35e9727a7..9e05c8fe24 100644 --- a/pkg/runtime.v2/framework/interface.go +++ b/pkg/runtime.v2/framework/interface.go @@ -18,6 +18,7 @@ package framework import ( "context" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -55,3 +56,8 @@ type ComponentBuilderPlugin interface { Plugin Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error) } + +type TerminalConditionPlugin interface { + Plugin + TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) +} diff --git a/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go index 57be2432c8..3f0484d37c 100644 --- a/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go +++ b/pkg/runtime.v2/framework/plugins/coscheduling/coscheduling.go @@ -100,7 +100,7 @@ func (c *CoScheduling) EnforcePodGroupPolicy(info *runtime.Info, trainJob *kubef return nil } -func (c *CoScheduling) Build(ctx context.Context, runtimeJobTemplate client.Object, info 
*runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error) { +func (c *CoScheduling) Build(ctx context.Context, _ client.Object, info *runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error) { if info == nil || info.RuntimePolicy.PodGroupPolicy == nil || info.RuntimePolicy.PodGroupPolicy.Coscheduling == nil || trainJob == nil { return nil, nil } diff --git a/pkg/runtime.v2/framework/plugins/jobset/jobset.go b/pkg/runtime.v2/framework/plugins/jobset/jobset.go index ef04890b39..e372b34857 100644 --- a/pkg/runtime.v2/framework/plugins/jobset/jobset.go +++ b/pkg/runtime.v2/framework/plugins/jobset/jobset.go @@ -50,6 +50,7 @@ type JobSet struct { var _ framework.WatchExtensionPlugin = (*JobSet)(nil) var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) +var _ framework.TerminalConditionPlugin = (*JobSet)(nil) const Name = constants.JobSetKind @@ -126,6 +127,24 @@ func jobSetIsSuspended(jobSet *jobsetv1alpha2.JobSet) bool { return ptr.Deref(jobSet.Spec.Suspend, false) } +func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) { + jobSet := &jobsetv1alpha2.JobSet{} + if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), jobSet); err != nil { + return nil, err + } + if completed := meta.FindStatusCondition(jobSet.Status.Conditions, string(jobsetv1alpha2.JobSetCompleted)); completed != nil && completed.Status == metav1.ConditionTrue { + completed.Reason = fmt.Sprintf("%sDueTo%s", completed.Type, completed.Reason) + completed.Type = kubeflowv2.TrainJobComplete + return completed, nil + } + if failed := meta.FindStatusCondition(jobSet.Status.Conditions, string(jobsetv1alpha2.JobSetFailed)); failed != nil && failed.Status == metav1.ConditionTrue { + failed.Reason = fmt.Sprintf("%sDueTo%s", failed.Type, failed.Reason) + failed.Type = kubeflowv2.TrainJobFailed + return failed, nil + } + return nil, nil +} + func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder { if _, err := j.restMapper.RESTMapping( schema.GroupKind{Group: jobsetv1alpha2.GroupVersion.Group, Kind: constants.JobSetKind}, diff --git a/pkg/runtime.v2/interface.go b/pkg/runtime.v2/interface.go index 8c735ad4f1..b23e7f35e0 100644 --- a/pkg/runtime.v2/interface.go +++ b/pkg/runtime.v2/interface.go @@ -19,6 +19,7 @@ package runtimev2 import ( "context" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" @@ -31,6 +32,7 @@ type ReconcilerBuilder func(*builder.Builder, client.Client) *builder.Builder type Runtime interface { NewObjects(ctx context.Context, trainJob *kubeflowv2.TrainJob) ([]client.Object, error) + TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) EventHandlerRegistrars() []ReconcilerBuilder ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) } diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index 4f19ba3d2c..fdbb3dd6c7 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -282,6 +282,13 @@ func (j *JobSetWrapper) Annotation(key, value string) *JobSetWrapper { return j } +func (j *JobSetWrapper) Conditions(conditions ...metav1.Condition) *JobSetWrapper { + if len(conditions) != 0 { + j.Status.Conditions = append(j.Status.Conditions, conditions...) 
+ } + return j +} + func (j *JobSetWrapper) Obj() *jobsetv1alpha2.JobSet { return &j.JobSet } diff --git a/test/integration/controller.v2/trainjob_controller_test.go b/test/integration/controller.v2/trainjob_controller_test.go index 39ce245227..78012cb1ef 100644 --- a/test/integration/controller.v2/trainjob_controller_test.go +++ b/test/integration/controller.v2/trainjob_controller_test.go @@ -18,16 +18,17 @@ package controllerv2 import ( "fmt" - "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + jobsetconsts "sigs.k8s.io/jobset/pkg/constants" schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" @@ -332,6 +333,159 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) + + ginkgo.It("Should succeeded to reconcile TrainJob conditions with Complete condition", func() { + ginkgo.By("Creating TrainingRuntime and suspended TrainJob") + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if JobSet and PodGroup are created") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, trainJobKey, &schedulerpluginsv1alpha1.PodGroup{})).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking if TrainJob has Suspended and Created conditions") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &kubeflowv2.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.Conditions).Should(gomega.BeComparableTo([]metav1.Condition{ + { + Type: kubeflowv2.TrainJobSuspended, + Status: metav1.ConditionTrue, + Reason: kubeflowv2.TrainJobSuspendedReason, + Message: "TrainJob is suspended", + }, + { + Type: kubeflowv2.TrainJobCreated, + Status: metav1.ConditionTrue, + Reason: kubeflowv2.TrainJobJobsCreationSucceededReason, + Message: "Succeeded to create Jobs", + }, + }, util.IgnoreConditions)) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking if the TrainJob has Resumed and Created conditions after unsuspended") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &kubeflowv2.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + gotTrainJob.Spec.Suspend = ptr.To(false) + g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.Conditions).Should(gomega.BeComparableTo([]metav1.Condition{ + { + Type: kubeflowv2.TrainJobSuspended, + Status: metav1.ConditionFalse, + Reason: kubeflowv2.TrainJobResumedReason, + Message: "TrainJob is resumed", + }, + { + Type: kubeflowv2.TrainJobCreated, + Status: metav1.ConditionTrue, + Reason: kubeflowv2.TrainJobJobsCreationSucceededReason, + Message: "Succeeded to create Jobs", + }, + }, util.IgnoreConditions)) + }, util.Timeout, 
util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating the JobSet condition with Completed") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + meta.SetStatusCondition(&jobSet.Status.Conditions, metav1.Condition{ + Type: string(jobsetv1alpha2.JobSetCompleted), + Reason: jobsetconsts.AllJobsCompletedReason, + Message: jobsetconsts.AllJobsCompletedMessage, + Status: metav1.ConditionTrue, + }) + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking if the TranJob has Resumed, Created, and Completed conditions") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &kubeflowv2.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.Conditions).Should(gomega.BeComparableTo([]metav1.Condition{ + { + Type: kubeflowv2.TrainJobSuspended, + Status: metav1.ConditionFalse, + Reason: kubeflowv2.TrainJobResumedReason, + Message: "TrainJob is resumed", + }, + { + Type: kubeflowv2.TrainJobCreated, + Status: metav1.ConditionTrue, + Reason: kubeflowv2.TrainJobJobsCreationSucceededReason, + Message: "Succeeded to create Jobs", + }, + { + Type: kubeflowv2.TrainJobComplete, + Status: metav1.ConditionTrue, + Reason: fmt.Sprintf("%sDueTo%s", jobsetv1alpha2.JobSetCompleted, jobsetconsts.AllJobsCompletedReason), + Message: jobsetconsts.AllJobsCompletedMessage, + }, + }, util.IgnoreConditions)) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.It("Should succeeded to reconcile TrainJob conditions with Failed condition", func() { + ginkgo.By("Creating TrainingRuntime and suspended TrainJob") + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if JobSet and PodGroup are created") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, trainJobKey, &schedulerpluginsv1alpha1.PodGroup{})).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Unsuspending the TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &kubeflowv2.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + gotTrainJob.Spec.Suspend = ptr.To(false) + g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating the JobSet condition with Failed") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + meta.SetStatusCondition(&jobSet.Status.Conditions, metav1.Condition{ + Type: string(jobsetv1alpha2.JobSetFailed), + Reason: jobsetconsts.FailedJobsReason, + Message: jobsetconsts.FailedJobsMessage, + Status: metav1.ConditionTrue, + }) + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Checking if the TranJob has Resumed, Created, and Failed conditions") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &kubeflowv2.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) 
+ g.Expect(gotTrainJob.Status.Conditions).Should(gomega.BeComparableTo([]metav1.Condition{ + { + Type: kubeflowv2.TrainJobSuspended, + Status: metav1.ConditionFalse, + Reason: kubeflowv2.TrainJobResumedReason, + Message: "TrainJob is resumed", + }, + { + Type: kubeflowv2.TrainJobCreated, + Status: metav1.ConditionTrue, + Reason: kubeflowv2.TrainJobJobsCreationSucceededReason, + Message: "Succeeded to create Jobs", + }, + { + Type: kubeflowv2.TrainJobFailed, + Status: metav1.ConditionTrue, + Reason: fmt.Sprintf("%sDueTo%s", jobsetv1alpha2.JobSetFailed, jobsetconsts.FailedJobsReason), + Message: jobsetconsts.FailedJobsMessage, + }, + }, util.IgnoreConditions)) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) }) }) diff --git a/test/integration/framework/framework.go b/test/integration/framework/framework.go index a86c433d7e..832fce1867 100644 --- a/test/integration/framework/framework.go +++ b/test/integration/framework/framework.go @@ -26,12 +26,12 @@ import ( "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" + "go.uber.org/zap/zapcore" "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" - "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/manager" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" @@ -51,7 +51,7 @@ type Framework struct { } func (f *Framework) Init() *rest.Config { - log.SetLogger(zap.New(zap.WriteTo(ginkgo.GinkgoWriter), zap.UseDevMode(true))) + ctrl.SetLogger(zap.New(zap.WriteTo(ginkgo.GinkgoWriter), zap.Level(zapcore.Level(-5)), zap.UseDevMode(true))) ginkgo.By("bootstrapping test environment") f.testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{ diff --git a/test/util/constants.go b/test/util/constants.go index a0b9d8a665..434bfeaac2 100644 --- a/test/util/constants.go +++ b/test/util/constants.go @@ -34,4 +34,7 @@ var ( cmpopts.IgnoreTypes(metav1.TypeMeta{}), cmpopts.IgnoreFields(metav1.ObjectMeta{}, "UID", "ResourceVersion", "Generation", "CreationTimestamp", "ManagedFields"), } + IgnoreConditions = cmp.Options{ + cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime", "ObservedGeneration"), + } )
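
Reviewer note (not part of the patch): with the new condition types, a consumer can detect a finished TrainJob the same way isTrainJobFinished does in the controller above. Below is a minimal client-side sketch, assuming a controller-runtime client and the kubeflowv2 package from this repository; the polling interval, timeout, and function name are illustrative only.

package example

import (
	"context"
	"time"

	"k8s.io/apimachinery/pkg/api/meta"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/apimachinery/pkg/util/wait"
	"sigs.k8s.io/controller-runtime/pkg/client"

	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
)

// WaitForTerminal polls a TrainJob until it carries a True Complete or Failed
// condition, mirroring isTrainJobFinished in the controller, or until the
// timeout expires.
func WaitForTerminal(ctx context.Context, c client.Client, key types.NamespacedName) (*kubeflowv2.TrainJob, error) {
	trainJob := &kubeflowv2.TrainJob{}
	err := wait.PollUntilContextTimeout(ctx, 2*time.Second, 10*time.Minute, true, func(ctx context.Context) (bool, error) {
		// Any Get error (including NotFound) stops the poll; relax this if the
		// TrainJob may not exist yet when polling starts.
		if err := c.Get(ctx, key, trainJob); err != nil {
			return false, err
		}
		return meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobComplete) ||
			meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobFailed), nil
	})
	return trainJob, err
}

The x-kubernetes-list-type: map declaration added to .status.conditions lets server-side apply merge conditions by their type key, and standard tooling such as `kubectl wait trainjob/<name> --for=condition=Complete` can watch for the new terminal conditions.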
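
Reviewer note (not part of the patch): the JobSet plugin's TerminalCondition above derives the TrainJob terminal condition by rewriting the JobSet's own Completed or Failed condition in place, keeping the status and message but prefixing the reason with the original type. A standalone sketch of that mapping for a terminal (Completed/Failed) input condition; the input values here are illustrative.

package main

import (
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
)

// toTrainJobTerminal rewrites a terminal JobSet condition into the TrainJob
// form used by the JobSet plugin: reason "AllJobsCompleted" on type
// "Completed" becomes "JobSetCompletedDueToAllJobsCompleted" on type "Complete".
func toTrainJobTerminal(cond metav1.Condition) metav1.Condition {
	out := cond
	out.Reason = fmt.Sprintf("%sDueTo%s", cond.Type, cond.Reason)
	if cond.Type == string(jobsetv1alpha2.JobSetCompleted) {
		out.Type = kubeflowv2.TrainJobComplete
	} else {
		out.Type = kubeflowv2.TrainJobFailed
	}
	return out
}

func main() {
	got := toTrainJobTerminal(metav1.Condition{
		Type:   string(jobsetv1alpha2.JobSetCompleted),
		Reason: "AllJobsCompleted",
		Status: metav1.ConditionTrue,
	})
	fmt.Println(got.Type, got.Reason) // Complete JobSetCompletedDueToAllJobsCompleted
}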
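
Reviewer note (not part of the patch): RunTerminalConditionPlugins above tolerates at most one registered TerminalConditionPlugin until the Configuration API lands. A minimal sketch of a custom plugin satisfying the new interface, modeled on fakeTerminalConditionPlugin from the framework tests; the package and plugin names are assumptions for illustration.

package noopterminal

import (
	"context"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"

	kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
	"github.com/kubeflow/training-operator/pkg/runtime.v2/framework"
)

// NoopTerminal never reports a terminal condition, so a TrainJob driven by a
// runtime that registers only this plugin stays non-terminal.
type NoopTerminal struct{}

var _ framework.TerminalConditionPlugin = (*NoopTerminal)(nil)

const Name = "NoopTerminal"

// New follows the factory shape the plugin registry expects
// (see newFakeTerminalConditionPlugin in the framework tests).
func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) {
	return &NoopTerminal{}, nil
}

func (p *NoopTerminal) Name() string { return Name }

func (p *NoopTerminal) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
	// Returning nil, nil means "no terminal condition yet"; the controller then
	// leaves the TrainJob's Complete/Failed conditions untouched.
	return nil, nil
}

Registering a plugin like this alongside the JobSet plugin would trip errorTooManyTerminalConditionPlugin, as exercised by the "multiple terminalCondition plugin" test case above.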