Skip to content

Commit

Permalink
Implement retry util (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaitanyaKulkarni28 authored Feb 7, 2024
1 parent bc138e1 commit 3e94187
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 0 deletions.
107 changes: 107 additions & 0 deletions retry/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2024 Google LLC

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// https://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package retry implements retry logic helpers to execute arbitrary functions with defined policy.
package retry

import (
"context"
"fmt"
"math"
"time"

"github.com/GoogleCloudPlatform/guest-logging-go/logger"
)

// IsRetriable is method signature for implementing to override default logic of retrying each error.
type IsRetriable func(error) bool

// Policy represents the struct to configure the retry behavior.
type Policy struct {
// MaxAttempts represents the maximum number of retry attempts.
MaxAttempts int
// BackoffFactor is the multiplier by which retry interval (Jitter) increases after each retry.
// For constant backoff set Backoff factor to 1.
BackoffFactor float64
// Jitter is the interval before the first retry.
Jitter time.Duration
// ShouldRetry is optional and the way to override default retry logic of retry every error.
// If ShouldRetry is not provided/implemented every error will be retried until all attempts are exhausted.
ShouldRetry IsRetriable
}

// backoff computes interval between retries. Interval is jitter*(backoffFactor^attempt).
// For e.g. if jitter was set to 10 and factor was 3, backoff between attempts would be [10, 30, 90, 270...].
func backoff(attempt int, policy Policy) time.Duration {
b := float64(policy.Jitter) * math.Pow(policy.BackoffFactor, float64(attempt))
return time.Duration(b)
}

// isRetriable checks if error is retriable. If ShouldRetry is unimplemented it always returns
// true, otherwise overriden method's logic determines the retry behavior.
func isRetriable(policy Policy, err error) bool {
if policy.ShouldRetry == nil {
return true
}
return policy.ShouldRetry(err)
}

// RunWithResponse executes and retries the function on failure based on policy defined and returns response on success.
func RunWithResponse[T any](ctx context.Context, policy Policy, f func() (T, error)) (T, error) {
var (
res T
err error
)

if f == nil {
return res, fmt.Errorf("retry function cannot be nil")
}

for attempt := 0; attempt < policy.MaxAttempts; attempt++ {
if res, err = f(); err == nil {
return res, nil
}

if err != nil && !isRetriable(policy, err) {
return res, fmt.Errorf("giving up, retry policy returned false on error: %+v", err)
}

logger.Debugf("Attempt %d failed with error %+v", attempt, err)

// Return early, no need to wait if all retries have exhausted.
if attempt+1 >= policy.MaxAttempts {
return res, fmt.Errorf("exhausted all (%d) retries, last error: %+v", policy.MaxAttempts, err)
}

select {
case <-ctx.Done():
return res, ctx.Err()
case <-time.After(backoff(attempt, policy)):
}
}
return res, fmt.Errorf("num of retries set to 0, made no attempts to run")
}

// Run executes and retries the function on failure based on policy defined and returns nil-error on success.
func Run(ctx context.Context, policy Policy, f func() error) error {
if f == nil {
return fmt.Errorf("retry function cannot be nil")
}

fn := func() (any, error) {
return nil, f()
}
_, err := RunWithResponse(ctx, policy, fn)
return err
}
193 changes: 193 additions & 0 deletions retry/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
// Copyright 2024 Google LLC

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// https://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package retry

import (
"context"
"errors"
"fmt"
"testing"
"time"
)

func TestRetry(t *testing.T) {
ctx := context.Background()
ctr := 0

fn := func() error {
ctr++
if ctr == 2 {
return nil
}
return fmt.Errorf("fake error")
}

policy := Policy{MaxAttempts: 5, BackoffFactor: 2, Jitter: time.Millisecond}

if err := Run(ctx, policy, fn); err != nil {
t.Errorf("Retry(ctx, %+v, fn) failed unexpectedly, err: %+v", policy, err)
}

want := 2
if ctr != want {
t.Errorf("Retry(ctx, %+v, fn) retried %d times, should've returned after %d retries", policy, ctr, want)
}
}

func TestRetryError(t *testing.T) {
ctx := context.Background()
ctr := 0

fn := func() error {
ctr++
return fmt.Errorf("fake error")
}

policy := Policy{MaxAttempts: 4, BackoffFactor: 1, Jitter: time.Millisecond * 2}

if err := Run(ctx, policy, fn); err == nil {
t.Errorf("Retry(ctx, %+v, fn) succeded, want error", policy)
}

// Max retry attempts error.
if ctr != policy.MaxAttempts {
t.Errorf("Retry(ctx, %+v, fn) retried %d times, should've returned after %d retries", policy, ctr, policy.MaxAttempts)
}

// Zero attempts error.
zeroPolicy := Policy{MaxAttempts: 0, BackoffFactor: 1, Jitter: time.Millisecond * 2}

if err := Run(ctx, zeroPolicy, fn); err == nil {
t.Errorf("Retry(ctx, %+v, fn) succeded, want zero attempts error", zeroPolicy)
}

// Emtpy function error.
if err := Run(ctx, policy, nil); err == nil {
t.Errorf("Retry(ctx, %+v, nil) succeded, want nil function error", policy)
}

// Context cancelled error.
c, cancel := context.WithTimeout(ctx, time.Microsecond)
cancel()
if err := Run(c, policy, fn); err == nil {
t.Errorf("Retry(ctx, %+v, fn) succeded, want context error", policy)
}
}

func TestRetryWithResponse(t *testing.T) {
ctx := context.Background()
ctr := 0

fn := func() (int, error) {
ctr++
if ctr == 2 {
return ctr, nil
}
return -1, fmt.Errorf("fake error")
}

policy := Policy{MaxAttempts: 5, BackoffFactor: 1, Jitter: time.Millisecond}
want := 2
got, err := RunWithResponse(ctx, policy, fn)
if err != nil {
t.Errorf("RetryWithResponse(ctx, %+v, fn) failed unexpectedly, err: %+v", policy, err)
}
if got != want {
t.Errorf("RetryWithResponse(ctx, %+v, fn) = %d, want %d", policy, got, want)
}
if ctr != want {
t.Errorf("RetryWithResponse(ctx, %+v, fn) retried %d times, should've returned after %d retries", policy, ctr, want)
}
}

func TestBackoff(t *testing.T) {
tests := []struct {
name string
factor float64
attempts int
jitter time.Duration
want []time.Duration
}{
{
name: "constant_backoff",
factor: 1,
attempts: 5,
jitter: time.Duration(10),
want: []time.Duration{10, 10, 10, 10, 10},
},
{
name: "exponential_backoff_2",
factor: 2,
attempts: 4,
jitter: time.Duration(10),
want: []time.Duration{10, 20, 40, 80},
},
{
name: "exponential_backoff_3",
factor: 3,
attempts: 4,
jitter: time.Duration(10),
want: []time.Duration{10, 30, 90, 270},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy := Policy{MaxAttempts: tt.attempts, BackoffFactor: tt.factor, Jitter: tt.jitter}
for i := 0; i < tt.attempts; i++ {
if got := backoff(i, policy); got != tt.want[i] {
t.Errorf("backoff(%d, %+v) = %d, want %d", i, policy, got, tt.want[i])
}
}
})
}
}

func TestIsRetriable(t *testing.T) {
// Fake ShouldRetry() override.
f := func(err error) bool {
return !errors.Is(err, context.DeadlineExceeded)
}

tests := []struct {
name string
err error
policy Policy
want bool
}{
{
name: "no_override",
want: true,
},
{
name: "override_no_retry",
err: context.DeadlineExceeded,
policy: Policy{ShouldRetry: f},
want: false,
},
{
name: "override_retry",
err: fmt.Errorf("fake retriable error"),
policy: Policy{ShouldRetry: f},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isRetriable(tt.policy, tt.err); got != tt.want {
t.Errorf("isRetriable(%+v, %+v) = %t, want %t", tt.policy, tt.err, got, tt.want)
}
})
}
}

0 comments on commit 3e94187

Please sign in to comment.