@@ -17,13 +17,20 @@ package main
17
17
import (
18
18
"context"
19
19
"fmt"
20
+ "net/http"
21
+ "net/http/httptest"
20
22
"net/url"
21
23
"os"
24
+ "path/filepath"
22
25
"reflect"
26
+ "strings"
23
27
"testing"
28
+ "time"
24
29
30
+ "cloud.google.com/go/storage"
25
31
"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
26
32
"github.com/GoogleCloudPlatform/guest-agent/metadata"
33
+ "google.golang.org/api/option"
27
34
)
28
35
29
36
func TestMain (m * testing.M ) {
@@ -297,3 +304,151 @@ func TestGetWantedKeysError(t *testing.T) {
297
304
})
298
305
}
299
306
}
307
+
308
+ func TestDownloadURL (t * testing.T ) {
309
+ ctx := context .Background ()
310
+ ctr := make (map [string ]int )
311
+ // No need to wait longer, override for testing.
312
+ defaultRetryPolicy .Jitter = time .Millisecond
313
+
314
+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
315
+ // /retry should succeed within 2 retries; /fail should always fail.
316
+ if (r .URL .Path == "/retry" && ctr ["/retry" ] != 1 ) || strings .Contains (r .URL .Path , "fail" ) {
317
+ w .WriteHeader (400 )
318
+ }
319
+
320
+ fmt .Fprintf (w , r .URL .Path )
321
+ ctr [r .URL .Path ] = ctr [r .URL .Path ] + 1
322
+ }))
323
+ defer server .Close ()
324
+
325
+ tests := []struct {
326
+ name string
327
+ key string
328
+ wantErr bool
329
+ retries int
330
+ }{
331
+ {
332
+ name : "succeed_immediately" ,
333
+ key : "/immediate_download" ,
334
+ wantErr : false ,
335
+ retries : 1 ,
336
+ },
337
+ {
338
+ name : "succeed_after_retry" ,
339
+ key : "/retry" ,
340
+ wantErr : false ,
341
+ retries : 2 ,
342
+ },
343
+ {
344
+ name : "fail_retry_exhaust" ,
345
+ key : "/fail" ,
346
+ wantErr : true ,
347
+ retries : 3 ,
348
+ },
349
+ }
350
+ for _ , tt := range tests {
351
+ t .Run (tt .name , func (t * testing.T ) {
352
+ f , err := os .OpenFile (filepath .Join (t .TempDir (), tt .name ), os .O_RDWR | os .O_CREATE | os .O_TRUNC , 0755 )
353
+ if err != nil {
354
+ t .Fatalf ("Failed to setup test file: %v" , err )
355
+ }
356
+ defer f .Close ()
357
+ url := server .URL + tt .key
358
+ if err := downloadURL (ctx , url , f ); (err != nil ) != tt .wantErr {
359
+ t .Errorf ("downloadURL(ctx, %s, %s) error = [%v], wantErr %t" , url , f .Name (), err , tt .wantErr )
360
+ }
361
+
362
+ if ! tt .wantErr {
363
+ gotBytes , err := os .ReadFile (f .Name ())
364
+ if err != nil {
365
+ t .Errorf ("failed to read output file %q, with error: %v" , f .Name (), err )
366
+ }
367
+ if string (gotBytes ) != tt .key {
368
+ t .Errorf ("downloadURL(ctx, %s, %s) wrote = [%s], want [%s]" , url , f .Name (), string (gotBytes ), tt .key )
369
+ }
370
+ }
371
+
372
+ if ctr [tt .key ] != tt .retries {
373
+ t .Errorf ("downloadURL(ctx, %s, %s) retried [%d] times, should have returned after [%d] retries" , url , f .Name (), ctr [tt .key ], tt .retries )
374
+ }
375
+ })
376
+ }
377
+ }
378
+
379
+ func TestDownloadGSURL (t * testing.T ) {
380
+ ctx := context .Background ()
381
+ ctr := make (map [string ]int )
382
+ // No need to wait longer, override for testing.
383
+ defaultRetryPolicy .Jitter = time .Millisecond
384
+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
385
+ // Fake error for invalid object request.
386
+ if strings .Contains (r .URL .Path , "invalid" ) {
387
+ w .WriteHeader (404 )
388
+ }
389
+ fmt .Fprintf (w , r .URL .Path )
390
+ ctr [r .URL .Path ] = ctr [r .URL .Path ] + 1
391
+ }))
392
+ defer server .Close ()
393
+
394
+ var err error
395
+ httpClient := & http.Client {Transport : & http.Transport {}}
396
+ testStorageClient , err = storage .NewClient (ctx , option .WithHTTPClient (httpClient ), option .WithEndpoint (server .URL ))
397
+ if err != nil {
398
+ t .Fatalf ("Failed to setup test storage client, err: %+v" , err )
399
+ }
400
+ defer testStorageClient .Close ()
401
+
402
+ tests := []struct {
403
+ name string
404
+ bucket string
405
+ object string
406
+ wantErr bool
407
+ retries int
408
+ }{
409
+ {
410
+ name : "valid_object" ,
411
+ bucket : "valid" ,
412
+ object : "obj1" ,
413
+ wantErr : false ,
414
+ retries : 1 ,
415
+ },
416
+ {
417
+ name : "invalid_object" ,
418
+ bucket : "invalid" ,
419
+ object : "obj1" ,
420
+ wantErr : true ,
421
+ retries : 3 ,
422
+ },
423
+ }
424
+ for _ , tt := range tests {
425
+ t .Run (tt .name , func (t * testing.T ) {
426
+ f , err := os .OpenFile (filepath .Join (t .TempDir (), tt .name ), os .O_RDWR | os .O_CREATE | os .O_TRUNC , 0755 )
427
+ if err != nil {
428
+ t .Fatalf ("Failed to setup test file: %v" , err )
429
+ }
430
+ defer f .Close ()
431
+
432
+ if err := downloadGSURL (ctx , tt .bucket , tt .object , f ); (err != nil ) != tt .wantErr {
433
+ t .Errorf ("downloadGSURL(ctx, %s, %s, %s) error = [%+v], wantErr %t" , tt .bucket , tt .object , f .Name (), err , tt .wantErr )
434
+ }
435
+
436
+ want := fmt .Sprintf ("/%s/%s" , tt .bucket , tt .object )
437
+
438
+ if ! tt .wantErr {
439
+ gotBytes , err := os .ReadFile (f .Name ())
440
+ if err != nil {
441
+ t .Errorf ("failed to read output file %q, with error: %v" , f .Name (), err )
442
+ }
443
+
444
+ if string (gotBytes ) != want {
445
+ t .Errorf ("downloadGSURL(ctx, %s, %s, %s) wrote = [%s], want [%s]" , tt .bucket , tt .object , f .Name (), string (gotBytes ), want )
446
+ }
447
+ }
448
+
449
+ if ctr [want ] != tt .retries {
450
+ t .Errorf ("downloadGSURL(ctx, %s, %s, %s) retried [%d] times, should have returned after [%d] retries" , tt .bucket , tt .object , f .Name (), ctr [want ], tt .retries )
451
+ }
452
+ })
453
+ }
454
+ }
0 commit comments