@@ -386,6 +386,217 @@ func TestWaiterError(t *testing.T) {
386
386
}
387
387
}
388
388
389
+ func TestWaiterRetryAnyError (t * testing.T ) {
390
+ svc := & mockClient {Client : awstesting .NewClient (& aws.Config {
391
+ Region : aws .String ("mock-region" ),
392
+ })}
393
+ svc .Handlers .Send .Clear () // mock sending
394
+ svc .Handlers .Unmarshal .Clear ()
395
+ svc .Handlers .UnmarshalMeta .Clear ()
396
+ svc .Handlers .UnmarshalError .Clear ()
397
+ svc .Handlers .ValidateResponse .Clear ()
398
+
399
+ var reqNum int
400
+ results := []struct {
401
+ Out * MockOutput
402
+ Err error
403
+ }{
404
+ { // retry
405
+ Err : awserr .New (
406
+ "MockException1" , "mock exception message" , nil ,
407
+ ),
408
+ },
409
+ { // retry
410
+ Err : awserr .New (
411
+ "MockException2" , "mock exception message" , nil ,
412
+ ),
413
+ },
414
+ { // success
415
+ Out : & MockOutput {
416
+ States : []* MockState {
417
+ {aws .String ("running" )},
418
+ {aws .String ("running" )},
419
+ },
420
+ },
421
+ },
422
+ { // shouldn't happen
423
+ Out : & MockOutput {
424
+ States : []* MockState {
425
+ {aws .String ("running" )},
426
+ {aws .String ("running" )},
427
+ },
428
+ },
429
+ },
430
+ }
431
+
432
+ numBuiltReq := 0
433
+ svc .Handlers .Build .PushBack (func (r * request.Request ) {
434
+ numBuiltReq ++
435
+ })
436
+ svc .Handlers .Send .PushBack (func (r * request.Request ) {
437
+ code := http .StatusOK
438
+ r .HTTPResponse = & http.Response {
439
+ StatusCode : code ,
440
+ Status : http .StatusText (code ),
441
+ Body : ioutil .NopCloser (bytes .NewReader ([]byte {})),
442
+ }
443
+ })
444
+ svc .Handlers .Unmarshal .PushBack (func (r * request.Request ) {
445
+ if reqNum >= len (results ) {
446
+ t .Errorf ("too many polling requests made" )
447
+ return
448
+ }
449
+ r .Data = results [reqNum ].Out
450
+ reqNum ++
451
+ })
452
+ svc .Handlers .UnmarshalMeta .PushBack (func (r * request.Request ) {
453
+ // If there was an error unmarshal error will be called instead of unmarshal
454
+ // need to increment count here also
455
+ if err := results [reqNum ].Err ; err != nil {
456
+ r .Error = err
457
+ reqNum ++
458
+ }
459
+ })
460
+
461
+ w := request.Waiter {
462
+ MaxAttempts : 10 ,
463
+ Delay : request .ConstantWaiterDelay (0 ),
464
+ SleepWithContext : aws .SleepWithContext ,
465
+ Acceptors : []request.WaiterAcceptor {
466
+ {
467
+ State : request .SuccessWaiterState ,
468
+ Matcher : request .PathAllWaiterMatch ,
469
+ Argument : "States[].State" ,
470
+ Expected : "running" ,
471
+ },
472
+ {
473
+ State : request .RetryWaiterState ,
474
+ Matcher : request .ErrorWaiterMatch ,
475
+ Argument : "" ,
476
+ Expected : true ,
477
+ },
478
+ {
479
+ State : request .FailureWaiterState ,
480
+ Matcher : request .ErrorWaiterMatch ,
481
+ Argument : "" ,
482
+ Expected : "FailureException" ,
483
+ },
484
+ },
485
+ NewRequest : BuildNewMockRequest (svc , & MockInput {}),
486
+ }
487
+
488
+ err := w .WaitWithContext (aws .BackgroundContext ())
489
+ if err != nil {
490
+ t .Fatalf ("expected no error, but did get one: %v" , err )
491
+ }
492
+ if e , a := 3 , numBuiltReq ; e != a {
493
+ t .Errorf ("expect %d built requests got %d" , e , a )
494
+ }
495
+ if e , a := 3 , reqNum ; e != a {
496
+ t .Errorf ("expect %d reqNum got %d" , e , a )
497
+ }
498
+ }
499
+
500
+ func TestWaiterSuccessNoError (t * testing.T ) {
501
+ svc := & mockClient {Client : awstesting .NewClient (& aws.Config {
502
+ Region : aws .String ("mock-region" ),
503
+ })}
504
+ svc .Handlers .Send .Clear () // mock sending
505
+ svc .Handlers .Unmarshal .Clear ()
506
+ svc .Handlers .UnmarshalMeta .Clear ()
507
+ svc .Handlers .UnmarshalError .Clear ()
508
+ svc .Handlers .ValidateResponse .Clear ()
509
+
510
+ var reqNum int
511
+ results := []struct {
512
+ Out * MockOutput
513
+ Err error
514
+ }{
515
+ { // success
516
+ Out : & MockOutput {
517
+ States : []* MockState {
518
+ {aws .String ("pending" )},
519
+ },
520
+ },
521
+ },
522
+ { // shouldn't happen
523
+ Out : & MockOutput {
524
+ States : []* MockState {
525
+ {aws .String ("pending" )},
526
+ {aws .String ("pending" )},
527
+ },
528
+ },
529
+ },
530
+ }
531
+
532
+ numBuiltReq := 0
533
+ svc .Handlers .Build .PushBack (func (r * request.Request ) {
534
+ numBuiltReq ++
535
+ })
536
+ svc .Handlers .Send .PushBack (func (r * request.Request ) {
537
+ code := http .StatusOK
538
+ r .HTTPResponse = & http.Response {
539
+ StatusCode : code ,
540
+ Status : http .StatusText (code ),
541
+ Body : ioutil .NopCloser (bytes .NewReader ([]byte {})),
542
+ }
543
+ })
544
+ svc .Handlers .Unmarshal .PushBack (func (r * request.Request ) {
545
+ if reqNum >= len (results ) {
546
+ t .Errorf ("too many polling requests made" )
547
+ return
548
+ }
549
+ r .Data = results [reqNum ].Out
550
+ reqNum ++
551
+ })
552
+ svc .Handlers .UnmarshalMeta .PushBack (func (r * request.Request ) {
553
+ // If there was an error unmarshal error will be called instead of unmarshal
554
+ // need to increment count here also
555
+ if err := results [reqNum ].Err ; err != nil {
556
+ r .Error = err
557
+ reqNum ++
558
+ }
559
+ })
560
+
561
+ w := request.Waiter {
562
+ MaxAttempts : 10 ,
563
+ Delay : request .ConstantWaiterDelay (0 ),
564
+ SleepWithContext : aws .SleepWithContext ,
565
+ Acceptors : []request.WaiterAcceptor {
566
+ {
567
+ State : request .SuccessWaiterState ,
568
+ Matcher : request .PathAllWaiterMatch ,
569
+ Argument : "States[].State" ,
570
+ Expected : "running" ,
571
+ },
572
+ {
573
+ State : request .SuccessWaiterState ,
574
+ Matcher : request .ErrorWaiterMatch ,
575
+ Argument : "" ,
576
+ Expected : false ,
577
+ },
578
+ {
579
+ State : request .FailureWaiterState ,
580
+ Matcher : request .ErrorWaiterMatch ,
581
+ Argument : "" ,
582
+ Expected : "FailureException" ,
583
+ },
584
+ },
585
+ NewRequest : BuildNewMockRequest (svc , & MockInput {}),
586
+ }
587
+
588
+ err := w .WaitWithContext (aws .BackgroundContext ())
589
+ if err != nil {
590
+ t .Fatalf ("expected no error, but did get one" )
591
+ }
592
+ if e , a := 1 , numBuiltReq ; e != a {
593
+ t .Errorf ("expect %d built requests got %d" , e , a )
594
+ }
595
+ if e , a := 1 , reqNum ; e != a {
596
+ t .Errorf ("expect %d reqNum got %d" , e , a )
597
+ }
598
+ }
599
+
389
600
func TestWaiterStatus (t * testing.T ) {
390
601
svc := & mockClient {Client : awstesting .NewClient (& aws.Config {
391
602
Region : aws .String ("mock-region" ),
0 commit comments