Skip to content

Commit 8ca03df

Browse files
committed
KEP-2170: Implement TrainJob conditions
Signed-off-by: Yuki Iwai <[email protected]>
1 parent 9e46f9d commit 8ca03df

File tree

17 files changed

+503
-23
lines changed

17 files changed

+503
-23
lines changed

hack/violation_exception_v2alpha1.list

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/
1313
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,PodSpecOverride,Volumes
1414
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TorchElasticPolicy,Metrics
1515
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobSpec,PodSpecOverrides
16-
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,Conditions
1716
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,TrainJobStatus,JobsStatus
1817
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Args
1918
API rule violation: list_type_missing,github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1,Trainer,Command

manifests/v2/base/crds/kubeflow.org_trainjobs.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3055,6 +3055,9 @@ spec:
30553055
- type
30563056
type: object
30573057
type: array
3058+
x-kubernetes-list-map-keys:
3059+
- type
3060+
x-kubernetes-list-type: map
30583061
jobsStatus:
30593062
description: JobsStatus tracks the child Jobs in TrainJob.
30603063
items:

pkg/apis/kubeflow.org/v2alpha1/openapi_generated.go

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,43 @@ type TrainJob struct {
4848
Status TrainJobStatus `json:"status,omitempty"`
4949
}
5050

51+
const (
52+
// TrainJobSuspended means the TrainJob is suspended.
53+
TrainJobSuspended string = "Suspended"
54+
55+
// TrainJobComplete means that the TrainJob has completed its execution.
56+
TrainJobComplete string = "Complete"
57+
58+
// TrainJobFailed means that the actual jobs have failed their execution.
59+
TrainJobFailed string = "Failed"
60+
61+
// TrainJobCreated means that the actual jobs creation has succeeded.
62+
TrainJobCreated string = "Created"
63+
)
64+
65+
const (
66+
// TrainJobSuspendedReason is the "Suspended" condition reason.
67+
// When the TrainJob is suspended, this is added.
68+
TrainJobSuspendedReason string = "Suspended"
69+
70+
// TrainJobResumedReason is the "Suspended" condition reason.
71+
// When the TrainJob suspension is changed from True to False, this is added.
72+
TrainJobResumedReason string = "Resumed"
73+
74+
// TrainJobJobsCreationSucceededReason is the "Created" condition reason.
75+
// When creating the objects succeeded after building them succeeded, this is added.
76+
TrainJobJobsCreationSucceededReason string = "JobsCreationSucceeded"
77+
78+
// TrainJobJobsBuildFailedReason is the "Created" condition reason.
79+
// When building the objects based on the TrainJob and the specified runtime failed,
80+
// this is added.
81+
TrainJobJobsBuildFailedReason string = "JobsBuildFailed"
82+
83+
// TrainJobJobsCreationFailedReason is the "Created" condition reason.
84+
// When creating the objects failed even though building them succeeded, this is added.
85+
TrainJobJobsCreationFailedReason string = "JobsCreationFailed"
86+
)
87+
5188
// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object
5289
// +resource:path=trainjobs
5390

@@ -269,7 +306,13 @@ type ContainerOverride struct {
269306
// TrainJobStatus represents the current status of TrainJob.
270307
type TrainJobStatus struct {
271308
// Conditions for the TrainJob.
272-
Conditions []metav1.Condition `json:"conditions,omitempty"`
309+
//
310+
// +optional
311+
// +listType=map
312+
// +listMapKey=type
313+
// +patchStrategy=merge
314+
// +patchMergeKey=type
315+
Conditions []metav1.Condition `json:"conditions,omitempty" patchStrategy:"merge" patchMergeKey:"type"`
273316

274317
// JobsStatus tracks the child Jobs in TrainJob.
275318
JobsStatus []JobStatus `json:"jobsStatus,omitempty"`

pkg/controller.v2/trainjob_controller.go

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ import (
2222
"fmt"
2323

2424
"github.com/go-logr/logr"
25+
"k8s.io/apimachinery/pkg/api/equality"
26+
"k8s.io/apimachinery/pkg/api/meta"
27+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2528
"k8s.io/apimachinery/pkg/runtime/schema"
2629
"k8s.io/client-go/tools/record"
2730
"k8s.io/klog/v2"
@@ -36,6 +39,15 @@ import (
3639

3740
var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")
3841

42+
// objsOpState describes the outcome of building and applying the objects
// derived from a TrainJob during a single reconciliation.
type objsOpState int

const (
	// succeeded means every object was built and then created or updated.
	succeeded objsOpState = iota
	// buildFailed means building the objects (or resolving their
	// GroupVersionKind) returned an error before anything was applied.
	buildFailed
	// creationFailed means creating an object failed with an error other
	// than "already exists".
	creationFailed
	// updateFailed means updating an already existing object failed.
	updateFailed
)
50+
3951
type TrainJobReconciler struct {
4052
log logr.Logger
4153
client client.Client
@@ -63,29 +75,41 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
6375
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
6476
ctx = ctrl.LoggerInto(ctx, log)
6577
log.V(2).Info("Reconciling TrainJob")
66-
if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil {
67-
return ctrl.Result{}, err
78+
if isTrainJobFinished(&trainJob) {
79+
log.V(5).Info("TrainJob has already been finished")
80+
return ctrl.Result{}, nil
6881
}
69-
// TODO (tenzen-y): Do update the status.
70-
return ctrl.Result{}, nil
71-
}
72-
73-
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
74-
log := ctrl.LoggerFrom(ctx)
7582

7683
runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
7784
runtime, ok := r.runtimes[runtimeRefGK]
7885
if !ok {
79-
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
86+
return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
87+
}
88+
opState, err := r.reconcileObjects(ctx, runtime, &trainJob)
89+
90+
originStatus := trainJob.Status.DeepCopy()
91+
setSuspendedCondition(&trainJob)
92+
setCreatedCondition(&trainJob, opState)
93+
if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil {
94+
return ctrl.Result{}, errors.Join(err, terminalCondErr)
95+
}
96+
// Compare status against status. originStatus is a *TrainJobStatus, so
// comparing it with the whole *TrainJob can never be equal (the types
// differ) and would issue a status update on every reconciliation.
if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) {
	return ctrl.Result{}, errors.Join(err, r.client.Status().Update(ctx, &trainJob))
}
99+
return ctrl.Result{}, err
100+
}
101+
102+
func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) (objsOpState, error) {
103+
log := ctrl.LoggerFrom(ctx)
104+
81105
objs, err := runtime.NewObjects(ctx, trainJob)
82106
if err != nil {
83-
return err
107+
return buildFailed, err
84108
}
85109
for _, obj := range objs {
86110
var gvk schema.GroupVersionKind
87111
if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
88-
return err
112+
return buildFailed, err
89113
}
90114
logKeysAndValues := []any{
91115
"groupVersionKind", gvk.String(),
@@ -102,21 +126,91 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k
102126
}
103127
switch {
104128
case created:
105-
log.V(5).Info("Succeeded to create object", logKeysAndValues)
129+
log.V(5).Info("Succeeded to create object", logKeysAndValues...)
106130
continue
107131
case client.IgnoreAlreadyExists(creationErr) != nil:
108-
return creationErr
132+
return creationFailed, creationErr
109133
default:
110134
// This indicates CREATE operation has not been performed or the object has already existed in the cluster.
111135
if err = r.client.Update(ctx, obj); err != nil {
112-
return err
136+
return updateFailed, err
113137
}
114-
log.V(5).Info("Succeeded to update object", logKeysAndValues)
138+
log.V(5).Info("Succeeded to update object", logKeysAndValues...)
115139
}
116140
}
141+
return succeeded, nil
142+
}
143+
144+
func setCreatedCondition(trainJob *kubeflowv2.TrainJob, opState objsOpState) {
145+
var newCond metav1.Condition
146+
switch opState {
147+
case succeeded:
148+
newCond = metav1.Condition{
149+
Type: kubeflowv2.TrainJobCreated,
150+
Status: metav1.ConditionTrue,
151+
Message: "Succeeded to create Jobs",
152+
Reason: kubeflowv2.TrainJobJobsCreationSucceededReason,
153+
}
154+
case buildFailed:
155+
newCond = metav1.Condition{
156+
Type: kubeflowv2.TrainJobCreated,
157+
Status: metav1.ConditionFalse,
158+
Message: "Failed to build Jobs",
159+
Reason: kubeflowv2.TrainJobJobsBuildFailedReason,
160+
}
161+
// TODO (tenzen-y): Provide more granular the message based on creation or update failure.
162+
case creationFailed, updateFailed:
163+
newCond = metav1.Condition{
164+
Type: kubeflowv2.TrainJobCreated,
165+
Status: metav1.ConditionFalse,
166+
Message: "Failed to create Jobs",
167+
Reason: kubeflowv2.TrainJobJobsCreationFailedReason,
168+
}
169+
default:
170+
return
171+
}
172+
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
173+
}
174+
175+
func setSuspendedCondition(trainJob *kubeflowv2.TrainJob) {
176+
var newCond metav1.Condition
177+
switch {
178+
case ptr.Deref(trainJob.Spec.Suspend, false):
179+
newCond = metav1.Condition{
180+
Type: kubeflowv2.TrainJobSuspended,
181+
Status: metav1.ConditionTrue,
182+
Message: "TrainJob is suspended",
183+
Reason: kubeflowv2.TrainJobSuspendedReason,
184+
}
185+
case meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobSuspended):
186+
newCond = metav1.Condition{
187+
Type: kubeflowv2.TrainJobSuspended,
188+
Status: metav1.ConditionFalse,
189+
Message: "TrainJob is resumed",
190+
Reason: kubeflowv2.TrainJobResumedReason,
191+
}
192+
default:
193+
return
194+
}
195+
meta.SetStatusCondition(&trainJob.Status.Conditions, newCond)
196+
}
197+
198+
func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *kubeflowv2.TrainJob) error {
199+
terminalCond, err := runtime.TerminalCondition(ctx, trainJob)
200+
if err != nil {
201+
return err
202+
}
203+
if terminalCond != nil {
204+
meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond)
205+
}
117206
return nil
118207
}
119208

209+
func isTrainJobFinished(trainJob *kubeflowv2.TrainJob) bool {
210+
return meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobComplete) ||
211+
meta.IsStatusConditionTrue(trainJob.Status.Conditions, kubeflowv2.TrainJobFailed)
212+
}
213+
120214
func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
121215
return schema.GroupKind{
122216
Group: ptr.Deref(runtimeRef.APIGroup, ""),

pkg/runtime.v2/core/clustertrainingruntime.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"errors"
2222
"fmt"
2323

24+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2425
"k8s.io/apimachinery/pkg/runtime/schema"
2526
"k8s.io/apimachinery/pkg/util/validation/field"
2627
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -59,6 +60,10 @@ func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *kubef
5960
return r.buildObjects(ctx, trainJob, clTrainingRuntime.Spec.Template, clTrainingRuntime.Spec.MLPolicy, clTrainingRuntime.Spec.PodGroupPolicy)
6061
}
6162

63+
func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
64+
return r.TrainingRuntime.TerminalCondition(ctx, trainJob)
65+
}
66+
6267
func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
6368
return nil
6469
}

pkg/runtime.v2/core/trainingruntime.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"errors"
2222
"fmt"
23+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2324

2425
"k8s.io/apimachinery/pkg/runtime/schema"
2526
"k8s.io/apimachinery/pkg/util/validation/field"
@@ -127,6 +128,10 @@ func (r *TrainingRuntime) buildObjects(
127128
return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
128129
}
129130

131+
func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
132+
return r.framework.RunTerminalConditionPlugins(ctx, trainJob)
133+
}
134+
130135
func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
131136
var builders []runtime.ReconcilerBuilder
132137
for _, ex := range r.framework.WatchExtensionPlugins() {

pkg/runtime.v2/framework/core/framework.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ package core
1818

1919
import (
2020
"context"
21+
"errors"
2122

23+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2224
"k8s.io/apimachinery/pkg/util/validation/field"
2325
"sigs.k8s.io/controller-runtime/pkg/client"
2426
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
@@ -29,6 +31,8 @@ import (
2931
fwkplugins "github.com/kubeflow/training-operator/pkg/runtime.v2/framework/plugins"
3032
)
3133

34+
var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered")
35+
3236
type Framework struct {
3337
registry fwkplugins.Registry
3438
plugins map[string]framework.Plugin
@@ -37,6 +41,7 @@ type Framework struct {
3741
customValidationPlugins []framework.CustomValidationPlugin
3842
watchExtensionPlugins []framework.WatchExtensionPlugin
3943
componentBuilderPlugins []framework.ComponentBuilderPlugin
44+
terminalConditionPlugins []framework.TerminalConditionPlugin
4045
}
4146

4247
func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
@@ -66,6 +71,9 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
6671
if p, ok := plugin.(framework.ComponentBuilderPlugin); ok {
6772
f.componentBuilderPlugins = append(f.componentBuilderPlugins, p)
6873
}
74+
if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
75+
f.terminalConditionPlugins = append(f.terminalConditionPlugins, p)
76+
}
6977
}
7078
f.plugins = plugins
7179
return f, nil
@@ -118,6 +126,17 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTe
118126
return objs, nil
119127
}
120128

129+
func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *kubeflowv2.TrainJob) (*metav1.Condition, error) {
130+
// TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
131+
if len(f.terminalConditionPlugins) > 1 {
132+
return nil, errorTooManyTerminalConditionPlugin
133+
}
134+
if len(f.terminalConditionPlugins) != 0 {
135+
return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob)
136+
}
137+
return nil, nil
138+
}
139+
121140
func (f *Framework) WatchExtensionPlugins() []framework.WatchExtensionPlugin {
122141
return f.watchExtensionPlugins
123142
}

0 commit comments

Comments
 (0)