
Commit 7591f89

Jason Parraga committed
Add support for native preemption retries
Signed-off-by: Jason Parraga <[email protected]>
1 parent d85ae70 commit 7591f89

21 files changed: +639 −27 lines

Lines changed: 39 additions & 0 deletions (new file)

@@ -0,0 +1,39 @@
package preemption

import (
    "github.com/armadaproject/armada/internal/server/configuration"
    "strconv"
)

// AreRetriesEnabled determines whether preemption retries are enabled at the job level. Also returns whether the
// annotation was set.
func AreRetriesEnabled(annotations map[string]string) (bool, bool) {
    preemptionRetryEnabledStr, exists := annotations[configuration.PreemptionRetryEnabledAnnotation]
    if !exists {
        return false, false
    }

    preemptionRetryEnabled, err := strconv.ParseBool(preemptionRetryEnabledStr)
    if err != nil {
        return false, true
    }
    return preemptionRetryEnabled, true
}

// GetMaxRetryCount gets the max preemption retry count at a job level. Also returns whether the annotation was set.
func GetMaxRetryCount(annotations map[string]string) (uint, bool) {
    var preemptionRetryCountMax uint = 0
    preemptionRetryCountMaxStr, exists := annotations[configuration.PreemptionRetryCountMaxAnnotation]

    if !exists {
        return 0, false
    }
    maybePreemptionRetryCountMax, err := strconv.Atoi(preemptionRetryCountMaxStr)
    if err != nil {
        return 0, true
    } else {
        preemptionRetryCountMax = uint(maybePreemptionRetryCountMax)
    }

    return preemptionRetryCountMax, true
}
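To illustrate how these helpers read a job's annotations, a minimal sketch follows; it is not part of the commit, and the annotation values "true" and "3" are made-up examples.

// Illustration only (not part of this commit): calling the helpers above on a
// job's annotation map from within the preemption package.
package preemption

import (
    "fmt"

    "github.com/armadaproject/armada/internal/server/configuration"
)

func exampleRetrySettings() {
    annotations := map[string]string{
        configuration.PreemptionRetryEnabledAnnotation:  "true",
        configuration.PreemptionRetryCountMaxAnnotation: "3",
    }

    enabled, enabledSet := AreRetriesEnabled(annotations) // -> true, true
    maxRetries, countSet := GetMaxRetryCount(annotations) // -> 3, true
    fmt.Println(enabled, enabledSet, maxRetries, countSet)
}

Note that an unparseable value still reports the annotation as set: AreRetriesEnabled returns (false, true) and GetMaxRetryCount returns (0, true).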

internal/executor/service/job_state_reporter.go

Lines changed: 5 additions & 0 deletions

@@ -93,6 +93,11 @@ func (stateReporter *JobStateReporter) reportCurrentStatus(pod *v1.Pod) {
 	}

 	if pod.Status.Phase == v1.PodFailed {
+
+		if util.IsPodPreempted(pod) {
+			return
+		}
+
 		hasIssue := stateReporter.podIssueHandler.HasIssue(util.ExtractJobRunId(pod))
 		if hasIssue {
 			// Pod already being handled by issue handler
internal/executor/service/job_state_reporter_test.go

Lines changed: 21 additions & 0 deletions

@@ -166,6 +166,27 @@ func TestJobStateReporter_HandlesFailedPod_WithRetryableError(t *testing.T) {
 	assertExpectedEvents(t, before, eventReporter.ReceivedEvents, reflect.TypeOf(&armadaevents.EventSequence_Event_JobRunErrors{}))
 }

+func TestJobStateReporter_IgnoresPreemptedPod(t *testing.T) {
+	_, _, eventReporter, fakeClusterContext := setUpJobStateReporterTest(t)
+
+	before := makeTestPod(v1.PodStatus{Phase: v1.PodRunning})
+	after := before.DeepCopy()
+	after.Status = v1.PodStatus{
+		Phase: v1.PodFailed,
+		Conditions: []v1.PodCondition{
+			{
+				Type:   v1.DisruptionTarget,
+				Status: v1.ConditionTrue,
+				Reason: util.PreemptedReason,
+			},
+		},
+	}
+
+	fakeClusterContext.SimulateUpdateAddEvent(before, after)
+	time.Sleep(time.Millisecond * 100) // Give time for async routine to process message
+	assert.Len(t, eventReporter.ReceivedEvents, 0)
+}
+
 func setUpJobStateReporterTest(t *testing.T) (*JobStateReporter, *stubIssueHandler, *mocks.FakeEventReporter, *fakecontext.SyncFakeClusterContext) {
 	fakeClusterContext := fakecontext.NewSyncFakeClusterContext()
 	eventReporter := mocks.NewFakeEventReporter()

internal/executor/service/pod_issue_handler.go

Lines changed: 1 addition & 1 deletion

@@ -521,7 +521,7 @@ func createStuckPodMessage(retryable bool, originalMessage string) string {
 func (p *PodIssueHandler) handleDeletedPod(pod *v1.Pod) {
 	jobId := util.ExtractJobId(pod)
 	if jobId != "" {
-		isUnexpectedDeletion := !util.IsMarkedForDeletion(pod) && !util.IsPodFinishedAndReported(pod)
+		isUnexpectedDeletion := !util.IsMarkedForDeletion(pod) && !util.IsPodFinishedAndReported(pod) && !util.IsPodPreempted(pod)
 		if isUnexpectedDeletion {
 			p.attemptToRegisterIssue(&runIssue{
 				JobId: jobId,

internal/executor/service/pod_issue_handler_test.go

Lines changed: 21 additions & 0 deletions

@@ -329,6 +329,27 @@ func TestPodIssueService_ReportsFailed_IfDeletedExternally(t *testing.T) {
 	assert.Equal(t, jobId, failedEvent.JobRunErrors.JobId)
 }

+func TestPodIssueService_IgnoresFailed_IfDeletedExternallyDueToPreemption(t *testing.T) {
+	podIssueService, _, fakeClusterContext, eventsReporter, err := setupTestComponents([]*job.RunState{})
+	require.NoError(t, err)
+	preemptedPod := makeTestPod(v1.PodStatus{
+		Phase: v1.PodFailed,
+		Conditions: []v1.PodCondition{
+			{
+				Type:   v1.DisruptionTarget,
+				Status: v1.ConditionTrue,
+				Reason: util.PreemptedReason,
+			},
+		},
+	})
+	fakeClusterContext.SimulateDeletionEvent(preemptedPod)
+
+	podIssueService.HandlePodIssues()
+
+	assert.Len(t, eventsReporter.ReceivedEvents, 0)
+	assert.Len(t, podIssueService.knownPodIssues, 0)
+}
+
 func TestPodIssueService_ReportsFailed_IfPodOfActiveRunGoesMissing(t *testing.T) {
 	baseTime := time.Now()
 	fakeClock := clock.NewFakeClock(baseTime)

internal/executor/util/pod_status.go

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ const (
 	oomKilledReason = "OOMKilled"
 	evictedReason = "Evicted"
 	deadlineExceeded = "DeadlineExceeded"
+	PreemptedReason = "PreemptionByScheduler"
 )

 // TODO: Need to detect pod preemption. So that job failed events can include a string indicating a pod was preempted.

internal/executor/util/pod_util.go

Lines changed: 12 additions & 0 deletions

@@ -370,3 +370,15 @@ func GroupByQueue(pods []*v1.Pod) map[string][]*v1.Pod {
 	}
 	return podsByQueue
 }
+
+func IsPodPreempted(pod *v1.Pod) bool {
+	for _, containerCondition := range pod.Status.Conditions {
+		if containerCondition.Type == v1.DisruptionTarget && containerCondition.Status == v1.ConditionTrue {
+			if containerCondition.Reason == PreemptedReason {
+				return true
+			}
+		}
+	}
+
+	return false
+}
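For contrast, a short illustrative snippet (not part of the diff): a pod disrupted for some other reason also carries a DisruptionTarget condition, but with a different Reason, so IsPodPreempted reports false. The "EvictionByEvictionAPI" reason below stands in for any non-preemption disruption reason.

// Illustration only (not part of this commit), assumed to live in package util.
func exampleNonPreemptionDisruption() bool {
    disruptedPod := &v1.Pod{
        Status: v1.PodStatus{
            Phase: v1.PodFailed,
            Conditions: []v1.PodCondition{
                {
                    Type:   v1.DisruptionTarget,
                    Status: v1.ConditionTrue,
                    // Any reason other than PreemptedReason ("PreemptionByScheduler")
                    // is not treated as a preemption.
                    Reason: "EvictionByEvictionAPI",
                },
            },
        },
    }
    return IsPodPreempted(disruptedPod) // false
}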

internal/executor/util/pod_util_test.go

Lines changed: 49 additions & 0 deletions

@@ -13,6 +13,55 @@ import (
 	"github.com/armadaproject/armada/internal/server/configuration"
 )

+func TestIsPodPreempted(t *testing.T) {
+
+	t.Run("preempted pod is preempted", func(t *testing.T) {
+		pod := v1.Pod{
+			Status: v1.PodStatus{
+				Phase: v1.PodFailed,
+				Conditions: []v1.PodCondition{
+					{
+						Type:   v1.DisruptionTarget,
+						Status: v1.ConditionTrue,
+						Reason: PreemptedReason,
+					},
+				},
+			},
+		}
+
+		assert.True(t, IsPodPreempted(&pod))
+	})
+
+	t.Run("failed pod is not preempted", func(t *testing.T) {
+		pod := v1.Pod{
+			Status: v1.PodStatus{
+				Phase: v1.PodFailed,
+				ContainerStatuses: []v1.ContainerStatus{
+					{
+						State: v1.ContainerState{
+							Terminated: &v1.ContainerStateTerminated{
+								ExitCode: 1,
+							},
+						},
+					},
+				},
+			},
+		}
+
+		assert.False(t, IsPodPreempted(&pod))
+	})
+
+	t.Run("successful pod is not preempted", func(t *testing.T) {
+		pod := v1.Pod{
+			Status: v1.PodStatus{
+				Phase: v1.PodSucceeded,
+			},
+		}
+
+		assert.False(t, IsPodPreempted(&pod))
+	})
+}
+
 func TestIsInTerminalState_ShouldReturnTrueWhenPodInSucceededPhase(t *testing.T) {
 	pod := v1.Pod{
 		Status: v1.PodStatus{

internal/scheduler/configuration/configuration.go

Lines changed: 7 additions & 0 deletions

@@ -248,6 +248,8 @@ type SchedulingConfig struct {
 	Pools []PoolConfig
 	ExperimentalIndicativePricing ExperimentalIndicativePricing
 	ExperimentalIndicativeShare ExperimentalIndicativeShare
+	// Default preemption retry settings so you don't have to annotate all jobs with retries.
+	DefaultPreemptionRetry PreemptionRetryConfig
 }

 const (
@@ -354,3 +356,8 @@ type PriorityOverrideConfig struct {
 	ServiceUrl string
 	ForceNoTls bool
 }
+
+type PreemptionRetryConfig struct {
+	Enabled bool
+	DefaultMaxRetryCount *uint
+}
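As a rough sketch (not part of this commit), the job-level annotation helpers from the new preemption package could be combined with this cluster-wide default along the following lines. The resolveMaxRetries helper, its precedence rules, and the omitted imports are assumptions for illustration only, not the scheduler's actual wiring.

// resolveMaxRetries is a hypothetical helper: job-level annotations, when set,
// take precedence over the configured defaults. The "preemption" qualifier
// refers to the new package introduced at the top of this commit; its import
// path is not shown in this view.
func resolveMaxRetries(annotations map[string]string, cfg PreemptionRetryConfig) (uint, bool) {
    if enabled, set := preemption.AreRetriesEnabled(annotations); set {
        if !enabled {
            return 0, false
        }
        if count, countSet := preemption.GetMaxRetryCount(annotations); countSet {
            return count, true
        }
    } else if !cfg.Enabled {
        return 0, false
    }
    // Fall back to the cluster-wide default retry count, if one is configured.
    if cfg.DefaultMaxRetryCount != nil {
        return *cfg.DefaultMaxRetryCount, true
    }
    return 0, false
}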

internal/scheduler/database/job_repository.go

Lines changed: 2 additions & 0 deletions

@@ -283,6 +283,7 @@ func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, run
 		WHERE runs.run_id IS NULL
 		OR runs.succeeded = true
 		OR runs.failed = true
+		OR runs.preempted = true
 		OR runs.cancelled = true;`

 	rows, err := tx.Query(ctx, fmt.Sprintf(query, tmpTable))
@@ -331,6 +332,7 @@ func (r *PostgresJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, ex
 		AND jr.succeeded = false
 		AND jr.failed = false
 		AND jr.cancelled = false
+		AND jr.preempted = false
 		ORDER BY jr.serial
 		LIMIT %d;
 	`
