Skip to content

Commit 736a759

Browse files
author
Akshay Chitneni
committed
Adding v2 trainjob validation webhook
fixing runtime
1 parent 9ed4112 commit 736a759

File tree

18 files changed

+392
-76
lines changed

18 files changed

+392
-76
lines changed

pkg/controller.v2/trainjob_controller.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@ package controllerv2
1818

1919
import (
2020
"context"
21-
"errors"
2221
"fmt"
22+
runtimeUtils "github.com/kubeflow/training-operator/pkg/util.v2/runtime"
2323

2424
"github.com/go-logr/logr"
2525
"k8s.io/apimachinery/pkg/runtime/schema"
2626
"k8s.io/client-go/tools/record"
2727
"k8s.io/klog/v2"
28-
"k8s.io/utils/ptr"
2928
ctrl "sigs.k8s.io/controller-runtime"
3029
"sigs.k8s.io/controller-runtime/pkg/client"
3130
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
@@ -34,8 +33,6 @@ import (
3433
jobruntimes "github.com/kubeflow/training-operator/pkg/runtime.v2"
3534
)
3635

37-
var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")
38-
3936
type TrainJobReconciler struct {
4037
log logr.Logger
4138
client client.Client
@@ -73,10 +70,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
7370
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
7471
log := ctrl.LoggerFrom(ctx)
7572

76-
runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
73+
runtimeRefGK := runtimeUtils.RuntimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
7774
runtime, ok := r.runtimes[runtimeRefGK]
7875
if !ok {
79-
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
76+
return fmt.Errorf("%w: %s", runtimeUtils.ErrorUnsupportedRuntime, runtimeRefGK)
8077
}
8178
objs, err := runtime.NewObjects(ctx, trainJob)
8279
if err != nil {
@@ -117,13 +114,6 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k
117114
return nil
118115
}
119116

120-
func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
121-
return schema.GroupKind{
122-
Group: ptr.Deref(runtimeRef.APIGroup, ""),
123-
Kind: ptr.Deref(runtimeRef.Kind, ""),
124-
}
125-
}
126-
127117
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
128118
b := ctrl.NewControllerManagedBy(mgr).
129119
For(&kubeflowv2.TrainJob{})

pkg/runtime.v2/core/clustertrainingruntime.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,17 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
6464
}
6565

6666
func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
67+
clusterTrainingRuntime := &kubeflowv2.ClusterTrainingRuntime{}
6768
if err := r.client.Get(ctx, client.ObjectKey{
68-
Namespace: old.Namespace,
69-
Name: old.Spec.RuntimeRef.Name,
70-
}, &kubeflowv2.ClusterTrainingRuntime{}); err != nil {
69+
Namespace: new.Namespace,
70+
Name: new.Spec.RuntimeRef.Name,
71+
}, clusterTrainingRuntime); err != nil {
7172
return nil, field.ErrorList{
72-
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
73+
field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef,
7374
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
7475
}
7576
}
76-
return r.framework.RunCustomValidationPlugins(old, new)
77+
info := r.getRuntimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy,
78+
clusterTrainingRuntime.Spec.PodGroupPolicy)
79+
return r.framework.RunCustomValidationPlugins(old, new, info)
7780
}

pkg/runtime.v2/core/trainingruntime.go

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,21 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.T
8484
func (r *TrainingRuntime) buildObjects(
8585
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy,
8686
) ([]client.Object, error) {
87+
88+
info := r.getRuntimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)
89+
if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil {
90+
return nil, err
91+
}
92+
err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info)
93+
if err != nil {
94+
return nil, err
95+
}
96+
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
97+
}
98+
99+
func (r *TrainingRuntime) getRuntimeInfo(
100+
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy) *runtime.Info {
101+
87102
propagationLabels := jobSetTemplateSpec.Labels
88103
if propagationLabels == nil && trainJob.Spec.Labels != nil {
89104
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
@@ -118,14 +133,7 @@ func (r *TrainingRuntime) buildObjects(
118133
Spec: *jobSetTemplateSpec.Spec.DeepCopy(),
119134
}, opts...)
120135

121-
if err := r.framework.RunEnforceMLPolicyPlugins(info); err != nil {
122-
return nil, err
123-
}
124-
err := r.framework.RunEnforcePodGroupPolicyPlugins(trainJob, info)
125-
if err != nil {
126-
return nil, err
127-
}
128-
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
136+
return info
129137
}
130138

131139
func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
@@ -137,14 +145,16 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
137145
}
138146

139147
func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
148+
trainingRuntime := &kubeflowv2.TrainingRuntime{}
140149
if err := r.client.Get(ctx, client.ObjectKey{
141-
Namespace: old.Namespace,
142-
Name: old.Spec.RuntimeRef.Name,
143-
}, &kubeflowv2.TrainingRuntime{}); err != nil {
150+
Namespace: new.Namespace,
151+
Name: new.Spec.RuntimeRef.Name,
152+
}, trainingRuntime); err != nil {
144153
return nil, field.ErrorList{
145-
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
154+
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef,
146155
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
147156
}
148157
}
149-
return r.framework.RunCustomValidationPlugins(old, new)
158+
info := r.getRuntimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy)
159+
return r.framework.RunCustomValidationPlugins(old, new, info)
150160
}

pkg/runtime.v2/framework/core/framework.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,12 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo
8989
return nil
9090
}
9191

92-
func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
92+
func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob,
93+
runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {
9394
var aggregatedWarnings admission.Warnings
9495
var aggregatedErrors field.ErrorList
9596
for _, plugin := range f.customValidationPlugins {
96-
warnings, errs := plugin.Validate(oldObj, newObj)
97+
warnings, errs := plugin.Validate(oldObj, newObj, runtimeInfo)
9798
if len(warnings) != 0 {
9899
aggregatedWarnings = append(aggregatedWarnings, warnings...)
99100
}

pkg/runtime.v2/framework/core/framework_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func TestNew(t *testing.T) {
8080
customValidationPlugins: []framework.CustomValidationPlugin{
8181
&mpi.MPI{},
8282
&torch.Torch{},
83+
&jobset.JobSet{},
8384
},
8485
watchExtensionPlugins: []framework.WatchExtensionPlugin{
8586
&coscheduling.CoScheduling{},
@@ -314,7 +315,8 @@ func TestRunCustomValidationPlugins(t *testing.T) {
314315
if err != nil {
315316
t.Fatal(err)
316317
}
317-
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
318+
runtimeInfo := runtime.NewInfo(testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test").Obj())
319+
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj, runtimeInfo)
318320
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
319321
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
320322
}

pkg/runtime.v2/framework/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface {
4848

4949
type CustomValidationPlugin interface {
5050
Plugin
51-
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
51+
Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList)
5252
}
5353

5454
type ComponentBuilderPlugin interface {

pkg/runtime.v2/framework/plugins/jobset/jobset.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package jobset
1919
import (
2020
"context"
2121
"fmt"
22+
"k8s.io/apimachinery/pkg/util/validation/field"
2223
"maps"
24+
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
2325

2426
"github.com/go-logr/logr"
2527
batchv1 "k8s.io/api/batch/v1"
@@ -50,6 +52,7 @@ type JobSet struct {
5052

5153
var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
5254
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
55+
var _ framework.CustomValidationPlugin = (*JobSet)(nil)
5356

5457
const Name = "JobSet"
5558

@@ -140,3 +143,115 @@ func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {
140143
},
141144
}
142145
}
146+
147+
func (j *JobSet) Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {
148+
149+
var allErrs field.ErrorList
150+
specPath := field.NewPath("spec")
151+
152+
jobSet, ok := runtimeInfo.Obj.(*jobsetv1alpha2.JobSet)
153+
if !ok {
154+
return nil, nil
155+
}
156+
157+
if newObj.Spec.ModelConfig != nil {
158+
// validate `model-initializer` container in the `Initializer` Job
159+
if newObj.Spec.ModelConfig.Input != nil {
160+
modelConfigInputPath := specPath.Child("modelConfig").Child("input")
161+
if len(jobSet.Spec.ReplicatedJobs) == 0 {
162+
allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime should have replicated jobs configured with model config input set"))
163+
} else {
164+
initializerJobFound := false
165+
modelInitializerContainerFound := false
166+
for _, job := range jobSet.Spec.ReplicatedJobs {
167+
if job.Name == "Initializer" {
168+
initializerJobFound = true
169+
for _, container := range job.Template.Spec.Template.Spec.Containers {
170+
if container.Name == "model-initializer" {
171+
modelInitializerContainerFound = true
172+
}
173+
}
174+
}
175+
}
176+
if !initializerJobFound {
177+
allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime should have replicated job configured with name - Initializer"))
178+
} else if !modelInitializerContainerFound {
179+
allErrs = append(allErrs, field.Invalid(modelConfigInputPath, newObj.Spec.ModelConfig.Input, "trainingRuntime with replicated job initializer should have container with name - model-initializer"))
180+
}
181+
}
182+
}
183+
184+
// validate `model-exporter` container in the `Exporter` Job
185+
if newObj.Spec.ModelConfig.Output != nil {
186+
modelConfigOutputPath := specPath.Child("modelConfig").Child("output")
187+
if len(jobSet.Spec.ReplicatedJobs) == 0 {
188+
allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime should have replicated jobs configured with model config output set"))
189+
} else {
190+
exporterJobFound := false
191+
modelExporterContainerFound := false
192+
for _, job := range jobSet.Spec.ReplicatedJobs {
193+
if job.Name == "Exporter" {
194+
exporterJobFound = true
195+
for _, container := range job.Template.Spec.Template.Spec.Containers {
196+
if container.Name == "model-exporter" {
197+
modelExporterContainerFound = true
198+
}
199+
}
200+
}
201+
}
202+
if !exporterJobFound {
203+
allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime should have replicated job configured with name - Exporter"))
204+
} else if !modelExporterContainerFound {
205+
allErrs = append(allErrs, field.Invalid(modelConfigOutputPath, newObj.Spec.ModelConfig.Output, "trainingRuntime with replicated job Exporter should have container with name - model-exporter"))
206+
}
207+
}
208+
}
209+
}
210+
211+
if len(newObj.Spec.PodSpecOverrides) != 0 {
212+
podSpecOverridesPath := specPath.Child("podSpecOverrides")
213+
jobsMap := map[string]bool{}
214+
for _, job := range jobSet.Spec.ReplicatedJobs {
215+
jobsMap[job.Name] = true
216+
}
217+
// validate if jobOverrides are valid
218+
for idx, override := range newObj.Spec.PodSpecOverrides {
219+
for _, job := range override.TargetJobs {
220+
if _, found := jobsMap[job.Name]; !found {
221+
allErrs = append(allErrs, field.Invalid(podSpecOverridesPath, newObj.Spec.PodSpecOverrides, fmt.Sprintf("job: %s, configured in the podOverride should be present in the referenced training runtime", job)))
222+
}
223+
}
224+
if len(override.Containers) != 0 {
225+
// validate if containerOverrides are valid
226+
containerMap := map[string]bool{}
227+
for _, job := range jobSet.Spec.ReplicatedJobs {
228+
for _, container := range job.Template.Spec.Template.Spec.Containers {
229+
containerMap[container.Name] = true
230+
}
231+
}
232+
containerOverridePath := podSpecOverridesPath.Index(idx)
233+
for _, container := range override.Containers {
234+
if _, found := containerMap[container.Name]; !found {
235+
allErrs = append(allErrs, field.Invalid(containerOverridePath, override.Containers, fmt.Sprintf("container: %s, configured in the containerOverride should be present in the referenced training runtime", container.Name)))
236+
}
237+
}
238+
}
239+
if len(override.InitContainers) != 0 {
240+
// validate if initContainerOverrides are valid
241+
initContainerMap := map[string]bool{}
242+
for _, job := range jobSet.Spec.ReplicatedJobs {
243+
for _, initContainer := range job.Template.Spec.Template.Spec.InitContainers {
244+
initContainerMap[initContainer.Name] = true
245+
}
246+
}
247+
initContainerOverridePath := podSpecOverridesPath.Index(idx)
248+
for _, container := range override.Containers {
249+
if _, found := initContainerMap[container.Name]; !found {
250+
allErrs = append(allErrs, field.Invalid(initContainerOverridePath, override.InitContainers, fmt.Sprintf("initContainer: %s, configured in the initContainerOverride should be present in the referenced training runtime", container.Name)))
251+
}
252+
}
253+
}
254+
}
255+
}
256+
return nil, allErrs
257+
}

pkg/runtime.v2/framework/plugins/mpi/mpi.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package mpi
1818

1919
import (
2020
"context"
21+
"strconv"
2122

2223
"k8s.io/apimachinery/pkg/util/validation/field"
2324
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -55,7 +56,16 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info) error {
5556
return nil
5657
}
5758

58-
// TODO: Need to implement validations for MPIJob.
59-
func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
60-
return nil, nil
59+
func (m *MPI) Validate(oldJobObj, newJobObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {
60+
var allErrs field.ErrorList
61+
specPath := field.NewPath("spec")
62+
if newJobObj.Spec.Trainer != nil {
63+
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
64+
if runtimeInfo.MLPolicy.MPI != nil {
65+
if _, err := strconv.Atoi(*newJobObj.Spec.Trainer.NumProcPerNode); err != nil {
66+
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
67+
}
68+
}
69+
}
70+
return nil, allErrs
6171
}

pkg/runtime.v2/framework/plugins/torch/torch.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ package torch
1818

1919
import (
2020
"context"
21+
"k8s.io/utils/strings/slices"
22+
"strconv"
2123

2224
"k8s.io/apimachinery/pkg/util/validation/field"
2325
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -51,7 +53,20 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info) error {
5153
return nil
5254
}
5355

54-
// TODO: Need to implement validateions for TorchJob.
55-
func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
56-
return nil, nil
56+
func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob, runtimeInfo *runtime.Info) (admission.Warnings, field.ErrorList) {
57+
var allErrs field.ErrorList
58+
specPath := field.NewPath("spec")
59+
60+
if newObj.Spec.Trainer != nil {
61+
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
62+
if runtimeInfo.MLPolicy.Torch != nil {
63+
allowedStringValList := []string{"auto", "cpu", "gpu"}
64+
if !slices.Contains(allowedStringValList, *newObj.Spec.Trainer.NumProcPerNode) {
65+
if _, err := strconv.Atoi(*newObj.Spec.Trainer.NumProcPerNode); err != nil {
66+
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newObj.Spec.Trainer.NumProcPerNode, "should have an int value or auto/cpu/gpu"))
67+
}
68+
}
69+
}
70+
}
71+
return nil, allErrs
5772
}

pkg/util.v2/runtime/runtime.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package runtime
2+
3+
import (
4+
"errors"
5+
kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
6+
"k8s.io/apimachinery/pkg/runtime/schema"
7+
"k8s.io/utils/ptr"
8+
)
9+
10+
// ErrorUnsupportedRuntime is the sentinel error for a runtime reference that
// is not supported. Callers wrap it with %w, so match it with errors.Is.
var ErrorUnsupportedRuntime = errors.New("the specified runtime is not supported")
11+
12+
func RuntimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
13+
return schema.GroupKind{
14+
Group: ptr.Deref(runtimeRef.APIGroup, ""),
15+
Kind: ptr.Deref(runtimeRef.Kind, ""),
16+
}
17+
}

0 commit comments

Comments
 (0)