Skip to content

Commit

Permalink
Update metadata script runner, add tests (#357)
Browse files — browse the repository at this point in the history
  • Loading branch information
ChaitanyaKulkarni28 authored Feb 16, 2024
1 parent a13f564 commit 22eb487
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 37 deletions.
69 changes: 32 additions & 37 deletions google_metadata_script_runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package main

// TODO: compare log outputs in this utility to linux.
// TODO: standardize and consolidate retries.

import (
"bufio"
Expand All @@ -40,6 +39,7 @@ import (
"cloud.google.com/go/storage"
"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
"github.com/GoogleCloudPlatform/guest-agent/metadata"
"github.com/GoogleCloudPlatform/guest-agent/retry"
"github.com/GoogleCloudPlatform/guest-agent/utils"
"github.com/GoogleCloudPlatform/guest-logging-go/logger"
)
Expand Down Expand Up @@ -79,10 +79,13 @@ var (
// https://commondatastorage.googleapis.com/<bucket>/<object>
gsHTTPRegex3 = regexp.MustCompile(fmt.Sprintf(`^http[s]?://(?:commondata)?storage\.googleapis\.com/%s/%s$`, bucket, object))

// testStorageClient is used to override GCS client in unit tests.
testStorageClient *storage.Client

client metadata.MDSClientInterface
version string
// defaultRetryPolicy is default policy to retry up to 3 times, only wait 1 second between retries.
defaultRetryPolicy = retry.Policy{MaxAttempts: 3, BackoffFactor: 1, Jitter: time.Second}
)

func init() {
Expand All @@ -103,36 +106,35 @@ func downloadGSURL(ctx context.Context, bucket, object string, file *os.File) er
}
defer client.Close()

r, err := client.Bucket(bucket).Object(object).NewReader(ctx)
r, err := retry.RunWithResponse(ctx, defaultRetryPolicy, func() (*storage.Reader, error) {
r, err := client.Bucket(bucket).Object(object).NewReader(ctx)
return r, err
})
if err != nil {
return fmt.Errorf("error reading object %q: %v", object, err)
return err
}
defer r.Close()

_, err = io.Copy(file, r)
return err
}

func downloadURL(url string, file *os.File) error {
// Retry up to 3 times, only wait 1 second between retries.
var res *http.Response
var err error
for i := 1; ; i++ {
res, err = http.Get(url)
if err != nil && i > 3 {
return err
func downloadURL(ctx context.Context, url string, file *os.File) error {
res, err := retry.RunWithResponse(context.Background(), defaultRetryPolicy, func() (*http.Response, error) {
res, err := http.Get(url)
if err != nil {
return res, err
}
if err == nil {
break
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET %q, bad status: %s", url, res.Status)
}
time.Sleep(1 * time.Second)
return res, nil
})
if err != nil {
return err
}
defer res.Body.Close()

if res.StatusCode != http.StatusOK {
return fmt.Errorf("GET %q, bad status: %s", url, res.Status)
}

_, err = io.Copy(file, res.Body)
return err
}
Expand All @@ -142,34 +144,27 @@ func downloadScript(ctx context.Context, path string, file *os.File) error {
// particularly once a system is promoted to a domain controller.
// Try to lookup storage.googleapis.com and sleep for up to 100s if
// we get an error.
// TODO: do we need to do this on every script?
for i := 0; i < 20; i++ {
if _, err := net.LookupHost(storageURL); err == nil {
break
}
time.Sleep(5 * time.Second)
policy := retry.Policy{MaxAttempts: 20, BackoffFactor: 1, Jitter: time.Second * 5}
err := retry.Run(ctx, policy, func() error {
_, err := net.LookupHost(storageURL)
return err
})
if err != nil {
return fmt.Errorf("%q lookup failed, err: %+v", storageURL, err)
}

bucket, object := parseGCS(path)
if bucket != "" && object != "" {
// TODO: why is this retry outer, but downloadURL retry is inner?
// Retry up to 3 times, only wait 1 second between retries.
for i := 1; ; i++ {
err := downloadGSURL(ctx, bucket, object, file)
if err == nil {
return nil
}
if err != nil && i > 3 {
logger.Infof("Failed to download GCS path: %v", err)
break
}
time.Sleep(1 * time.Second)
err = downloadGSURL(ctx, bucket, object, file)
if err != nil {
logger.Infof("Failed to download object [%s] from GCS bucket [%s], err: %+v", object, bucket, err)
}
logger.Infof("Trying unauthenticated download")
path = fmt.Sprintf("https://%s/%s/%s", storageURL, bucket, object)
}

// Fall back to an HTTP GET of the URL.
return downloadURL(path, file)
return downloadURL(ctx, path, file)
}

func parseGCS(path string) (string, string) {
Expand Down
155 changes: 155 additions & 0 deletions google_metadata_script_runner/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,20 @@ package main
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"

"cloud.google.com/go/storage"
"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
"github.com/GoogleCloudPlatform/guest-agent/metadata"
"google.golang.org/api/option"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -297,3 +304,151 @@ func TestGetWantedKeysError(t *testing.T) {
})
}
}

// TestDownloadURL verifies that downloadURL writes the response body to the
// destination file and retries failed requests per defaultRetryPolicy.
func TestDownloadURL(t *testing.T) {
	ctx := context.Background()
	ctr := make(map[string]int)
	// No need to wait longer in tests: shrink the jitter, and restore it so
	// other tests are not affected by the mutated package-level policy.
	oldJitter := defaultRetryPolicy.Jitter
	defaultRetryPolicy.Jitter = time.Millisecond
	defer func() { defaultRetryPolicy.Jitter = oldJitter }()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// /retry should succeed within 2 retries; /fail should always fail.
		if (r.URL.Path == "/retry" && ctr["/retry"] != 1) || strings.Contains(r.URL.Path, "fail") {
			w.WriteHeader(400)
		}
		// Fprint, not Fprintf: the path is data, not a format string, and
		// may contain '%' characters (also silences go vet's printf check).
		fmt.Fprint(w, r.URL.Path)
		ctr[r.URL.Path] = ctr[r.URL.Path] + 1
	}))
	defer server.Close()

	tests := []struct {
		name    string
		key     string
		wantErr bool
		retries int
	}{
		{
			name:    "succeed_immediately",
			key:     "/immediate_download",
			wantErr: false,
			retries: 1,
		},
		{
			name:    "succeed_after_retry",
			key:     "/retry",
			wantErr: false,
			retries: 2,
		},
		{
			name:    "fail_retry_exhaust",
			key:     "/fail",
			wantErr: true,
			retries: 3,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			f, err := os.OpenFile(filepath.Join(t.TempDir(), tt.name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755)
			if err != nil {
				t.Fatalf("Failed to setup test file: %v", err)
			}
			defer f.Close()
			url := server.URL + tt.key
			if err := downloadURL(ctx, url, f); (err != nil) != tt.wantErr {
				t.Errorf("downloadURL(ctx, %s, %s) error = [%v], wantErr %t", url, f.Name(), err, tt.wantErr)
			}

			if !tt.wantErr {
				gotBytes, err := os.ReadFile(f.Name())
				if err != nil {
					t.Errorf("failed to read output file %q, with error: %v", f.Name(), err)
				}
				if string(gotBytes) != tt.key {
					t.Errorf("downloadURL(ctx, %s, %s) wrote = [%s], want [%s]", url, f.Name(), string(gotBytes), tt.key)
				}
			}

			if ctr[tt.key] != tt.retries {
				t.Errorf("downloadURL(ctx, %s, %s) retried [%d] times, should have returned after [%d] retries", url, f.Name(), ctr[tt.key], tt.retries)
			}
		})
	}
}

// TestDownloadGSURL verifies that downloadGSURL reads an object through the
// test-injected storage client, writes it to the destination file, and
// retries failed reads per defaultRetryPolicy.
func TestDownloadGSURL(t *testing.T) {
	ctx := context.Background()
	ctr := make(map[string]int)
	// No need to wait longer in tests: shrink the jitter, and restore it so
	// other tests are not affected by the mutated package-level policy.
	oldJitter := defaultRetryPolicy.Jitter
	defaultRetryPolicy.Jitter = time.Millisecond
	defer func() { defaultRetryPolicy.Jitter = oldJitter }()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Fake error for invalid object request.
		if strings.Contains(r.URL.Path, "invalid") {
			w.WriteHeader(404)
		}
		// Fprint, not Fprintf: the path is data, not a format string, and
		// may contain '%' characters (also silences go vet's printf check).
		fmt.Fprint(w, r.URL.Path)
		ctr[r.URL.Path] = ctr[r.URL.Path] + 1
	}))
	defer server.Close()

	var err error
	httpClient := &http.Client{Transport: &http.Transport{}}
	// Point the GCS client at the fake server via the package-level override.
	testStorageClient, err = storage.NewClient(ctx, option.WithHTTPClient(httpClient), option.WithEndpoint(server.URL))
	if err != nil {
		t.Fatalf("Failed to setup test storage client, err: %+v", err)
	}
	defer testStorageClient.Close()

	tests := []struct {
		name    string
		bucket  string
		object  string
		wantErr bool
		retries int
	}{
		{
			name:    "valid_object",
			bucket:  "valid",
			object:  "obj1",
			wantErr: false,
			retries: 1,
		},
		{
			name:    "invalid_object",
			bucket:  "invalid",
			object:  "obj1",
			wantErr: true,
			retries: 3,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			f, err := os.OpenFile(filepath.Join(t.TempDir(), tt.name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755)
			if err != nil {
				t.Fatalf("Failed to setup test file: %v", err)
			}
			defer f.Close()

			if err := downloadGSURL(ctx, tt.bucket, tt.object, f); (err != nil) != tt.wantErr {
				t.Errorf("downloadGSURL(ctx, %s, %s, %s) error = [%+v], wantErr %t", tt.bucket, tt.object, f.Name(), err, tt.wantErr)
			}

			want := fmt.Sprintf("/%s/%s", tt.bucket, tt.object)

			if !tt.wantErr {
				gotBytes, err := os.ReadFile(f.Name())
				if err != nil {
					t.Errorf("failed to read output file %q, with error: %v", f.Name(), err)
				}

				if string(gotBytes) != want {
					t.Errorf("downloadGSURL(ctx, %s, %s, %s) wrote = [%s], want [%s]", tt.bucket, tt.object, f.Name(), string(gotBytes), want)
				}
			}

			if ctr[want] != tt.retries {
				t.Errorf("downloadGSURL(ctx, %s, %s, %s) retried [%d] times, should have returned after [%d] retries", tt.bucket, tt.object, f.Name(), ctr[want], tt.retries)
			}
		})
	}
}

0 comments on commit 22eb487

Please sign in to comment.