Skip to content

Commit

Permalink
KEP-2170: Implement TrainJob conditions
Browse files Browse the repository at this point in the history
Signed-off-by: Yuki Iwai <[email protected]>
  • Loading branch information
tenzen-y committed Nov 7, 2024
1 parent 9e46f9d commit 8ca03df
Show file tree
Hide file tree
Showing 17 changed files with 503 additions and 23 deletions.
1 change: 0 additions & 1 deletion hack/violation_exception_v2alpha1.list
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,PodSpecOverride,Volumes
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TorchElasticPolicy,Metrics
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobSpec,PodSpecOverrides
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,Conditions
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,JobsStatus
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Args
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Command
Expand Down
3 changes: 3 additions & 0 deletions manifests/v2/base/crds/kubeflow.org_trainjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3055,6 +3055,9 @@ spec:
- type
type: object
type: array
x-kubernetes-list-map-keys:
- type
x-kubernetes-list-type: map
jobsStatus:
description: JobsStatus tracks the child Jobs in TrainJob.
items:
Expand Down
10 changes: 10 additions & 0 deletions pkg/apis/kubeflow.org/v2alpha1/openapi_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 44 additions & 1 deletion pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,43 @@ type TrainJob struct {
Status TrainJobStatus `json:"status,omitempty"`
}

const (
// TrainJobSuspended means the TrainJob is suspended.
TrainJobSuspended string = "Suspended"

// TrainJobComplete means that the TrainJob has completed its execution.
TrainJobComplete string = "Complete"

// TrainJobFailed means that the actual jobs have failed its execution.
TrainJobFailed string = "Failed"

// TrainJobCreated means that the actual jobs creation has succeeded.
TrainJobCreated string = "Created"
)

const (
// TrainJobSuspendedReason is the "Suspended" condition reason.
// When the TrainJob is suspended, this is added.
TrainJobSuspendedReason string = "Suspended"

// TrainJobResumedReason is the "Suspended" condition reason.
// When the TrainJob suspension is changed from True to False, this is added.
TrainJobResumedReason string = "Resumed"

// TrainJobJobsCreationSucceededReason is the "Created" condition reason.
// When the creating objects succeeded after building succeeded, this is added.
TrainJobJobsCreationSucceededReason string = "JobsCreationSucceeded"

// TrainJobJobsBuildFailedReason is the "Created" condition reason.
// When the building objects based on the TrainJob and the specified runtime failed,
// this is added.
TrainJobJobsBuildFailedReason string = "JobsBuildFailed"

// TrainJobJobsCreationFailedReason is the "Created" condition reason.
// When the creating objects failed even though building succeeded, this is added.
TrainJobJobsCreationFailedReason string = "JobsCreationFailed"
)

// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
// +resource:path=trainjobs

Expand Down Expand Up @@ -269,7 +306,13 @@ type ContainerOverride struct {
// TrainJobStatus represents the current status of TrainJob.
type TrainJobStatus struct {
// Conditions for the TrainJob.
Conditions []metav1.Condition `json:"conditions,omitempty"`
//
// +optional
// +listType=map
// +listMapKey=type
// +patchStrategy=merge
// +patchMergeKey=type
Conditions []metav1.Condition `json:"conditions,omitempty" patchStrategy:"merge" patchMergeKey:"type"`

// JobsStatus tracks the child Jobs in TrainJob.
JobsStatus []JobStatus `json:"jobsStatus,omitempty"`
Expand Down
124 changes: 109 additions & 15 deletions pkg/controller.v2/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ import (
"fmt"

"github.com/go-logr/logr"
"k8s.io/apimachinery/pkg/api/equality"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
Expand All @@ -36,6 +39,15 @@ import (

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type objsOpState int

const (
succeeded objsOpState = iota
buildFailed objsOpState = iota
creationFailed objsOpState = iota
updateFailed objsOpState = iota
)

type TrainJobReconciler struct {
log logr.Logger
client client.Client
Expand Down Expand Up @@ -63,29 +75,41 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
ctx = ctrl.LoggerInto(ctx, log)
log.V(2).Info("Reconciling TrainJob")
if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil {
return ctrl.Result{}, err
if isTrainJobFinished(&trainJob) {
log.V(5).Info("TrainJob has already been finished")
return ctrl.Result{}, nil
}
// TODO (tenzen-y): Do update the status.
return ctrl.Result{}, nil
}

func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
log := ctrl.LoggerFrom(ctx)

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
}
opState, err := r.reconcileObjects(ctx, runtime, &trainJob)

originStatus := trainJob.Status.DeepCopy()
setSuspendedCondition(&trainJob)
setCreatedCondition(&trainJob, opState)
if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil {
return ctrl.Result{}, errors.Join(err, terminalCondErr)
}
if !equality.Semantic.DeepEqual(&trainJob, originStatus) {
return ctrl.Result{}, errors.Join(err, r.client.Status().Update(ctx, &trainJob))
}
return ctrl.Result{}, err
}

func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) (objsOpState, error) {
log := ctrl.LoggerFrom(ctx)

objs, err := runtime.NewObjects(ctx, trainJob)
if err != nil {
return err
return buildFailed, err
}
for _, obj := range objs {
var gvk schema.GroupVersionKind
if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
return err
return buildFailed, err
}
logKeysAndValues := []any{
"groupVersionKind", gvk.String(),
Expand All @@ -102,21 +126,91 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k
}
switch {
case created:
log.V(5).Info("Succeeded to create object", logKeysAndValues)
log.V(5).Info("Succeeded to create object", logKeysAndValues...)
continue
case client.IgnoreAlreadyExists(creationErr) != nil:
return creationErr
return creationFailed, creationErr
default:
// This indicates CREATE operation has not been performed or the object has already existed in the cluster.
if err = r.client.Update(ctx, obj); err != nil {
return err
return updateFailed, err
}
log.V(5).Info("Succeeded to update object", logKeysAndValues)
log.V(5).Info("Succeeded to update object", logKeysAndValues...)
}
}
return succeeded, nil
}

func setCreatedCondition(trainJob *kubeflowv2.TrainJob, opState objsOpState) {
var newCond metav1.Condition
switch opState {
case succeeded:
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobCreated,
Status: metav1.ConditionTrue,
Message: "Succeeded to create Jobs",
Reason: kubeflowv2.TrainJobJobsCreationSucceededReason,
}
case buildFailed:
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobCreated,
Status: metav1.ConditionFalse,
Message: "Failed to build Jobs",
Reason: kubeflowv2.TrainJobJobsBuildFailedReason,
}
// TODO (tenzen-y): Provide more granular the message based on creation or update failure.
case creationFailed, updateFailed:
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobCreated,
Status: metav1.ConditionFalse,
Message: "Failed to create Jobs",
Reason: kubeflowv2.TrainJobJobsCreationFailedReason,
}
default:
return
}
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
}

func setSuspendedCondition(trainJob *kubeflowv2.TrainJob) {
var newCond metav1.Condition
switch {
case ptr.Deref(trainJob.Spec.Suspend, false):
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobSuspended,
Status: metav1.ConditionTrue,
Message: "TrainJob is suspended",
Reason: kubeflowv2.TrainJobSuspendedReason,
}
case meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobSuspended):
newCond = metav1.Condition{
Type: kubeflowv2.TrainJobSuspended,
Status: metav1.ConditionFalse,
Message: "TrainJob is resumed",
Reason: kubeflowv2.TrainJobResumedReason,
}
default:
return
}
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
}

func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) error {
terminalCond, err := runtime.TerminalCondition(ctx, trainJob)
if err != nil {
return err
}
if terminalCond != nil {
meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond)
}
return nil
}

func isTrainJobFinished(trainJob *kubeflowv2.TrainJob) bool {
return meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobComplete) ||
meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobFailed)
}

func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Expand Down
5 changes: 5 additions & 0 deletions pkg/runtime.v2/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -59,6 +60,10 @@ func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubef
return r.buildObjects(ctx, trainJob, clTrainingRuntime.Spec.Template, clTrainingRuntime.Spec.MLPolicy, clTrainingRuntime.Spec.PodGroupPolicy)
}

func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
return r.TrainingRuntime.TerminalCondition(ctx, trainJob)
}

func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
return nil
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"errors"
"fmt"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
Expand Down Expand Up @@ -127,6 +128,10 @@ func (r *TrainingRuntime) buildObjects(
return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
}

func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
return r.framework.RunTerminalConditionPlugins(ctx, trainJob)
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
var builders []runtime.ReconcilerBuilder
for _, ex := range r.framework.WatchExtensionPlugins() {
Expand Down
19 changes: 19 additions & 0 deletions pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package core

import (
"context"
"errors"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
Expand All @@ -29,6 +31,8 @@ import (
fwkplugins "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins"
)

var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered")

type Framework struct {
registry fwkplugins.Registry
plugins map[string]framework.Plugin
Expand All @@ -37,6 +41,7 @@ type Framework struct {
customValidationPlugins []framework.CustomValidationPlugin
watchExtensionPlugins []framework.WatchExtensionPlugin
componentBuilderPlugins []framework.ComponentBuilderPlugin
terminalConditionPlugins []framework.TerminalConditionPlugin
}

func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
Expand Down Expand Up @@ -66,6 +71,9 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
if p, ok := plugin.(framework.ComponentBuilderPlugin); ok {
f.componentBuilderPlugins = append(f.componentBuilderPlugins, p)
}
if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
f.terminalConditionPlugins = append(f.terminalConditionPlugins, p)
}
}
f.plugins = plugins
return f, nil
Expand Down Expand Up @@ -118,6 +126,17 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTe
return objs, nil
}

func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
// TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
if len(f.terminalConditionPlugins) > 1 {
return nil, errorTooManyTerminalConditionPlugin
}
if len(f.terminalConditionPlugins) != 0 {
return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob)
}
return nil, nil
}

func (f *Framework) WatchExtensionPlugins() []framework.WatchExtensionPlugin {
return f.watchExtensionPlugins
}
Loading

0 comments on commit 8ca03df

Please sign in to comment.