Skip to content

Commit 287daaa

Browse files
Add support for native preemption retries
Signed-off-by: Jason Parraga <[email protected]>
1 parent 90281f3 commit 287daaa

File tree

16 files changed

+619
-39
lines changed

16 files changed

+619
-39
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package preemption
2+
3+
import (
4+
"strconv"
5+
6+
"github.com/armadaproject/armada/internal/server/configuration"
7+
)
8+
9+
// AreRetriesEnabled determines whether preemption retries are enabled at the job level. Also returns whether the
10+
// annotation was set.
11+
func AreRetriesEnabled(annotations map[string]string) (enabled bool, annotationSet bool) {
12+
preemptionRetryEnabledStr, exists := annotations[configuration.PreemptionRetryEnabledAnnotation]
13+
if !exists {
14+
return false, false
15+
}
16+
17+
preemptionRetryEnabled, err := strconv.ParseBool(preemptionRetryEnabledStr)
18+
if err != nil {
19+
return false, true
20+
}
21+
return preemptionRetryEnabled, true
22+
}
23+
24+
// GetMaxRetryCount gets the max preemption retry count at a job level. Also returns whether the annotation was set.
25+
func GetMaxRetryCount(annotations map[string]string) (maxRetryCount uint, annotationSet bool) {
26+
var preemptionRetryCountMax uint = 0
27+
preemptionRetryCountMaxStr, exists := annotations[configuration.PreemptionRetryCountMaxAnnotation]
28+
29+
if !exists {
30+
return preemptionRetryCountMax, false
31+
}
32+
maybePreemptionRetryCountMax, err := strconv.Atoi(preemptionRetryCountMaxStr)
33+
if err != nil {
34+
return preemptionRetryCountMax, true
35+
} else {
36+
preemptionRetryCountMax = uint(maybePreemptionRetryCountMax)
37+
return preemptionRetryCountMax, true
38+
}
39+
}

internal/scheduler/configuration/configuration.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,8 @@ type SchedulingConfig struct {
248248
Pools []PoolConfig
249249
ExperimentalIndicativePricing ExperimentalIndicativePricing
250250
ExperimentalIndicativeShare ExperimentalIndicativeShare
251+
// Default preemption retries settings so you don't have to annotate all jobs with retries.
252+
DefaultPreemptionRetry PreemptionRetryConfig
251253
}
252254

253255
const (
@@ -354,3 +356,8 @@ type PriorityOverrideConfig struct {
354356
ServiceUrl string
355357
ForceNoTls bool
356358
}
359+
360+
type PreemptionRetryConfig struct {
361+
Enabled bool
362+
DefaultMaxRetryCount *uint
363+
}

internal/scheduler/database/job_repository.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ type JobRepository interface {
5454
CountReceivedPartitions(ctx *armadacontext.Context, groupId uuid.UUID) (uint32, error)
5555

5656
// FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active
57-
// Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled
57+
// Runs are inactive if they don't exist or if they have succeeded, failed, preempted or been cancelled
5858
FindInactiveRuns(ctx *armadacontext.Context, runIds []string) ([]string, error)
5959

6060
// FetchJobRunLeases fetches new job runs for a given executor. A maximum of maxResults rows will be returned, while run
@@ -293,7 +293,7 @@ func (r *PostgresJobRepository) FetchJobUpdates(ctx *armadacontext.Context, jobS
293293
}
294294

295295
// FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active
296-
// Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled
296+
// Runs are inactive if they don't exist or if they have succeeded, failed, preempted or been cancelled
297297
func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, runIds []string) ([]string, error) {
298298
var inactiveRuns []string
299299
err := pgx.BeginTxFunc(ctx, r.db, pgx.TxOptions{
@@ -313,6 +313,7 @@ func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, run
313313
WHERE runs.run_id IS NULL
314314
OR runs.succeeded = true
315315
OR runs.failed = true
316+
OR runs.preempted = true
316317
OR runs.cancelled = true;`
317318

318319
rows, err := tx.Query(ctx, fmt.Sprintf(query, tmpTable))
@@ -361,6 +362,7 @@ func (r *PostgresJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, ex
361362
AND jr.succeeded = false
362363
AND jr.failed = false
363364
AND jr.cancelled = false
365+
AND jr.preempted = false
364366
ORDER BY jr.serial
365367
LIMIT %d;
366368
`

internal/scheduler/database/job_repository_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,15 @@ func TestFindInactiveRuns(t *testing.T) {
553553
},
554554
expectedInactive: []string{runIds[1]},
555555
},
556+
"run preempted": {
557+
runsToCheck: runIds,
558+
dbRuns: []Run{
559+
{RunID: runIds[0]},
560+
{RunID: runIds[1], Preempted: true},
561+
{RunID: runIds[2]},
562+
},
563+
expectedInactive: []string{runIds[1]},
564+
},
556565
"run missing": {
557566
runsToCheck: runIds,
558567
dbRuns: []Run{
@@ -654,6 +663,14 @@ func TestFetchJobRunLeases(t *testing.T) {
654663
Pool: "test-pool",
655664
Succeeded: true, // should be ignored as terminal
656665
},
666+
{
667+
RunID: uuid.NewString(),
668+
JobID: dbJobs[0].JobID,
669+
JobSet: "test-jobset",
670+
Executor: executorName,
671+
Pool: "test-pool",
672+
Preempted: true, // should be ignored as terminal
673+
},
657674
}
658675
expectedLeases := make([]*JobRunLease, 4)
659676
for i := range expectedLeases {

internal/scheduler/jobdb/job.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import (
44
"fmt"
55
"time"
66

7+
"github.com/armadaproject/armada/internal/common/preemption"
8+
"github.com/armadaproject/armada/internal/scheduler/configuration"
9+
710
"github.com/hashicorp/go-multierror"
811
"github.com/pkg/errors"
912
"golang.org/x/exp/maps"
@@ -734,6 +737,59 @@ func (job *Job) NumAttempts() uint {
734737
return attempts
735738
}
736739

740+
// IsEligibleForPreemptionRetry determines whether the job is eligible for preemption retries. It checks whether the
741+
// scheduler or the job has opted in for preemption retries. It then checks whether the job has exhausted the number
742+
// of retries.
743+
func (job *Job) IsEligibleForPreemptionRetry(defaultPreemptionRetryConfig configuration.PreemptionRetryConfig) bool {
744+
enabled := false
745+
746+
// Check for platform default first
747+
if defaultPreemptionRetryConfig.Enabled {
748+
enabled = true
749+
}
750+
751+
// Check if job explicitly enabled/disabled retries
752+
jobRetryEnabled, exists := preemption.AreRetriesEnabled(job.Annotations())
753+
if exists {
754+
enabled = jobRetryEnabled
755+
}
756+
757+
if !enabled {
758+
return false
759+
}
760+
761+
maxRetryCount := job.MaxPreemptionRetryCount(defaultPreemptionRetryConfig)
762+
763+
return job.NumPreemptedRuns() <= maxRetryCount
764+
}
765+
766+
func (job *Job) NumPreemptedRuns() uint {
767+
preemptCount := uint(0)
768+
for _, run := range job.runsById {
769+
if run.preempted {
770+
preemptCount++
771+
}
772+
}
773+
return preemptCount
774+
}
775+
776+
func (job *Job) MaxPreemptionRetryCount(defaultPreemptionRetryConfig configuration.PreemptionRetryConfig) uint {
777+
var maxRetryCount uint = 0
778+
779+
// Check for platform default first
780+
if defaultPreemptionRetryConfig.DefaultMaxRetryCount != nil {
781+
platformDefaultMaxRetryCount := *defaultPreemptionRetryConfig.DefaultMaxRetryCount
782+
maxRetryCount = platformDefaultMaxRetryCount
783+
}
784+
785+
// Allow jobs to set a custom max retry count
786+
jobMaxRetryCount, exists := preemption.GetMaxRetryCount(job.Annotations())
787+
if exists {
788+
maxRetryCount = jobMaxRetryCount
789+
}
790+
return maxRetryCount
791+
}
792+
737793
// AllRuns returns all runs associated with job.
738794
func (job *Job) AllRuns() []*JobRun {
739795
return maps.Values(job.runsById)

internal/scheduler/jobdb/job_run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ func (run *JobRun) WithoutTerminal() *JobRun {
501501

502502
// InTerminalState returns true if the JobRun is in a terminal state
503503
func (run *JobRun) InTerminalState() bool {
504-
return run.succeeded || run.failed || run.cancelled || run.returned
504+
return run.succeeded || run.failed || run.cancelled || run.returned || run.preempted
505505
}
506506

507507
func (run *JobRun) DeepCopy() *JobRun {

internal/scheduler/jobdb/job_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ import (
1010

1111
"github.com/armadaproject/armada/internal/common/types"
1212
"github.com/armadaproject/armada/internal/scheduler/internaltypes"
13+
14+
configuration2 "github.com/armadaproject/armada/internal/scheduler/configuration"
15+
"github.com/armadaproject/armada/internal/server/configuration"
1316
)
1417

1518
var jobSchedulingInfo = &internaltypes.JobSchedulingInfo{
@@ -26,6 +29,35 @@ var jobSchedulingInfo = &internaltypes.JobSchedulingInfo{
2629
},
2730
}
2831

32+
var jobSchedulingInfoWithRetryEnabled = &internaltypes.JobSchedulingInfo{
33+
PodRequirements: &internaltypes.PodRequirements{
34+
ResourceRequirements: v1.ResourceRequirements{
35+
Requests: v1.ResourceList{
36+
"cpu": k8sResource.MustParse("1"),
37+
"storage-connections": k8sResource.MustParse("1"),
38+
},
39+
},
40+
Annotations: map[string]string{
41+
configuration.PreemptionRetryEnabledAnnotation: "true",
42+
configuration.PreemptionRetryCountMaxAnnotation: "1",
43+
},
44+
},
45+
}
46+
47+
var jobSchedulingInfoWithRetryDisabled = &internaltypes.JobSchedulingInfo{
48+
PodRequirements: &internaltypes.PodRequirements{
49+
ResourceRequirements: v1.ResourceRequirements{
50+
Requests: v1.ResourceList{
51+
"cpu": k8sResource.MustParse("1"),
52+
"storage-connections": k8sResource.MustParse("1"),
53+
},
54+
},
55+
Annotations: map[string]string{
56+
configuration.PreemptionRetryEnabledAnnotation: "false",
57+
},
58+
},
59+
}
60+
2961
var baseJob, _ = jobDb.NewJob(
3062
"test-job",
3163
"test-jobSet",
@@ -42,6 +74,38 @@ var baseJob, _ = jobDb.NewJob(
4274
[]string{},
4375
)
4476

77+
var baseJobWithRetryEnabled, _ = jobDb.NewJob(
78+
"test-job",
79+
"test-jobSet",
80+
"test-queue",
81+
2,
82+
jobSchedulingInfoWithRetryEnabled,
83+
true,
84+
0,
85+
false,
86+
false,
87+
false,
88+
3,
89+
false,
90+
[]string{},
91+
)
92+
93+
var baseJobWithRetryDisabled, _ = jobDb.NewJob(
94+
"test-job",
95+
"test-jobSet",
96+
"test-queue",
97+
2,
98+
jobSchedulingInfoWithRetryDisabled,
99+
true,
100+
0,
101+
false,
102+
false,
103+
false,
104+
3,
105+
false,
106+
[]string{},
107+
)
108+
45109
var baseRun = &JobRun{
46110
id: uuid.New().String(),
47111
created: 3,
@@ -425,3 +489,85 @@ func TestJob_TestKubernetesResourceRequirements(t *testing.T) {
425489
assert.Equal(t, int64(1000), baseJob.KubernetesResourceRequirements().GetByNameZeroIfMissing("cpu"))
426490
assert.Equal(t, int64(0), baseJob.KubernetesResourceRequirements().GetByNameZeroIfMissing("storage-connections"))
427491
}
492+
493+
func TestIsEligibleForPreemptionRetry(t *testing.T) {
494+
premptedRun1 := &JobRun{
495+
id: uuid.New().String(),
496+
created: 3,
497+
executor: "test-executor",
498+
preempted: true,
499+
}
500+
501+
premptedRun2 := &JobRun{
502+
id: uuid.New().String(),
503+
created: 5,
504+
executor: "test-executor",
505+
preempted: true,
506+
}
507+
508+
defaultMaxRetryCountEnabled := uint(5)
509+
platformDefaultEnabled := configuration2.PreemptionRetryConfig{
510+
Enabled: true,
511+
DefaultMaxRetryCount: &defaultMaxRetryCountEnabled,
512+
}
513+
514+
defaultMaxRetryCountDisabled := uint(0)
515+
platformDefaultDisabled := configuration2.PreemptionRetryConfig{
516+
Enabled: false,
517+
DefaultMaxRetryCount: &defaultMaxRetryCountDisabled,
518+
}
519+
520+
// no runs
521+
t.Run("job with retry enabled and platform disabled and no runs", func(t *testing.T) {
522+
assert.True(t, baseJobWithRetryEnabled.IsEligibleForPreemptionRetry(platformDefaultDisabled))
523+
})
524+
525+
t.Run("job with retry disabled and platform enabled and no runs", func(t *testing.T) {
526+
assert.False(t, baseJobWithRetryDisabled.IsEligibleForPreemptionRetry(platformDefaultEnabled))
527+
})
528+
529+
t.Run("job with platform retry enabled and no runs", func(t *testing.T) {
530+
assert.True(t, baseJob.IsEligibleForPreemptionRetry(platformDefaultEnabled))
531+
})
532+
533+
t.Run("job with platform retry disabled and no runs", func(t *testing.T) {
534+
assert.False(t, baseJob.IsEligibleForPreemptionRetry(platformDefaultDisabled))
535+
})
536+
537+
// runs but none are preempted
538+
t.Run("job with retry enabled and platform disabled runs but no preempted runs", func(t *testing.T) {
539+
updatedJob := baseJobWithRetryEnabled.WithUpdatedRun(baseRun).WithUpdatedRun(baseRun)
540+
assert.True(t, updatedJob.IsEligibleForPreemptionRetry(platformDefaultDisabled))
541+
})
542+
543+
t.Run("job with platform enabled runs but no preempted runs", func(t *testing.T) {
544+
updatedJob := baseJob.WithUpdatedRun(baseRun).WithUpdatedRun(baseRun)
545+
assert.True(t, updatedJob.IsEligibleForPreemptionRetry(platformDefaultEnabled))
546+
})
547+
548+
t.Run("job with retry enabled and platform disabled and one run", func(t *testing.T) {
549+
updatedJob := baseJobWithRetryEnabled.WithUpdatedRun(premptedRun1)
550+
assert.True(t, updatedJob.IsEligibleForPreemptionRetry(platformDefaultDisabled))
551+
})
552+
553+
t.Run("job with platform enabled and one run", func(t *testing.T) {
554+
updatedJob := baseJob.WithUpdatedRun(premptedRun1)
555+
assert.True(t, updatedJob.IsEligibleForPreemptionRetry(platformDefaultEnabled))
556+
})
557+
558+
// runs that are preempted
559+
t.Run("job with retry enabled and platform disabled and out of retries", func(t *testing.T) {
560+
updatedJob := baseJobWithRetryEnabled.WithUpdatedRun(premptedRun1).WithUpdatedRun(premptedRun2)
561+
assert.False(t, updatedJob.IsEligibleForPreemptionRetry(platformDefaultDisabled))
562+
})
563+
564+
t.Run("job with retry enabled with platform enabled and out of retries", func(t *testing.T) {
565+
updatedJob := baseJobWithRetryEnabled.WithUpdatedRun(premptedRun1).WithUpdatedRun(premptedRun2)
566+
assert.False(t, updatedJob.IsEligibleForPreemptionRetry(platformDefaultEnabled))
567+
})
568+
569+
t.Run("job with platform enabled and retries left", func(t *testing.T) {
570+
updatedJob := baseJob.WithUpdatedRun(premptedRun1).WithUpdatedRun(premptedRun2)
571+
assert.True(t, updatedJob.IsEligibleForPreemptionRetry(platformDefaultEnabled))
572+
})
573+
}

0 commit comments

Comments
 (0)