diff --git a/.travis.yml b/.travis.yml index 2ea99ac..eb8b7b9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: go go: - - 1.13 + - 1.15 before_install: - go get github.com/mattn/goveralls - go get golang.org/x/tools/cmd/cover diff --git a/cmd/update-aws-ecs-service/main.go b/cmd/update-aws-ecs-service/main.go index 9aa9fad..8ecd4cd 100644 --- a/cmd/update-aws-ecs-service/main.go +++ b/cmd/update-aws-ecs-service/main.go @@ -60,6 +60,7 @@ func main() { region := flag.String("region", "", "region name") taskdef := flag.String("taskdef", "", "base task definition (instead of current)") desiredCount := flag.Int64("desired-count", -1, "desired-count (negative: no change)") + taskrole := flag.String("task-role", "", fmt.Sprintf(`task iam role, set to "%s" to clear`, awsecs.TaskRoleKnockoutValue)) var images mapFlag = map[string]string{} var envs mapMapFlag = map[string]map[string]string{} @@ -91,6 +92,7 @@ func main() { Secrets: secrets, LogDriverOptions: logopts, LogDriverSecrets: logsecrets, + TaskRole: *taskrole, DesiredCount: int64ptr(*desiredCount), Taskdef: *taskdef, BackOff: backoff.NewExponentialBackOff(), diff --git a/ecs-alter-service.go b/ecs-alter-service.go index 6c84df0..83a0773 100644 --- a/ecs-alter-service.go +++ b/ecs-alter-service.go @@ -19,8 +19,8 @@ var ( ErrFailedRollback = errors.New("failed rollback") ) -func alterServiceOrValidatedRollBack(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, desiredCount *int64, taskdef string, bo backoff.BackOff) error { - oldsvc, alterSvcErr := alterServiceValidateDeployment(api, cluster, service, imageMap, envMaps, secretMaps, logopts, logsecrets, desiredCount, taskdef, bo) +func alterServiceOrValidatedRollBack(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, taskRole string, desiredCount *int64, taskdef string, bo backoff.BackOff) error { + oldsvc, alterSvcErr := alterServiceValidateDeployment(api, cluster, service, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole, desiredCount, taskdef, bo) if alterSvcErr != nil { operation := func() error { if oldsvc.ServiceName == nil { diff --git a/ecs.go b/ecs.go index 25d225f..54801d7 100644 --- a/ecs.go +++ b/ecs.go @@ -1,7 +1,6 @@ package awsecs import ( - "encoding/json" "errors" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" @@ -30,21 +29,16 @@ var ( ) func copyTd(input ecs.TaskDefinition, tags []*ecs.Tag) ecs.RegisterTaskDefinitionInput { - obj, err := json.Marshal(input) - if err != nil { - panic(err) - } + obj := panicMarshal(input) inputClone := ecs.TaskDefinition{} - err = json.Unmarshal(obj, &inputClone) - if err != nil { - panic(err) - } + panicUnmarshal(obj, &inputClone) output := ecs.RegisterTaskDefinitionInput{} + // TODO: replace with reflection output.ContainerDefinitions = inputClone.ContainerDefinitions output.Cpu = inputClone.Cpu output.ExecutionRoleArn = inputClone.ExecutionRoleArn output.Family = inputClone.Family - // output.InferenceAccelerators // not supported by the current version of the SDK + output.InferenceAccelerators = inputClone.InferenceAccelerators output.IpcMode = inputClone.IpcMode output.Memory = inputClone.Memory output.NetworkMode = inputClone.NetworkMode @@ -52,22 +46,17 @@ func copyTd(input ecs.TaskDefinition, tags []*ecs.Tag) ecs.RegisterTaskDefinitio output.PlacementConstraints = inputClone.PlacementConstraints output.ProxyConfiguration = inputClone.ProxyConfiguration output.RequiresCompatibilities = inputClone.RequiresCompatibilities - output.Tags = tags output.TaskRoleArn = inputClone.TaskRoleArn output.Volumes = inputClone.Volumes + // can't be replaced with reflection + output.Tags = tags return output } func alterImages(copy ecs.RegisterTaskDefinitionInput, imageMap map[string]string) ecs.RegisterTaskDefinitionInput { - obj, err := json.Marshal(copy) - if err != nil { - panic(err) - } + obj := panicMarshal(copy) copyClone := ecs.RegisterTaskDefinitionInput{} - err = json.Unmarshal(obj, ©Clone) - if err != nil { - panic(err) - } + panicUnmarshal(obj, ©Clone) for name, image := range imageMap { for _, containerDefinition := range copyClone.ContainerDefinitions { if *containerDefinition.Name == name { @@ -79,15 +68,9 @@ func alterImages(copy ecs.RegisterTaskDefinitionInput, imageMap map[string]strin } func alterEnvironments(copy ecs.RegisterTaskDefinitionInput, envMaps map[string]map[string]string) ecs.RegisterTaskDefinitionInput { - obj, err := json.Marshal(copy) - if err != nil { - panic(err) - } + obj := panicMarshal(copy) copyClone := ecs.RegisterTaskDefinitionInput{} - err = json.Unmarshal(obj, ©Clone) - if err != nil { - panic(err) - } + panicUnmarshal(obj, ©Clone) for name, envMap := range envMaps { for i, containerDefinition := range copyClone.ContainerDefinitions { if *containerDefinition.Name == name { @@ -100,15 +83,9 @@ func alterEnvironments(copy ecs.RegisterTaskDefinitionInput, envMaps map[string] } func alterSecrets(copy ecs.RegisterTaskDefinitionInput, secretMaps map[string]map[string]string) ecs.RegisterTaskDefinitionInput { - obj, err := json.Marshal(copy) - if err != nil { - panic(err) - } + obj := panicMarshal(copy) copyClone := ecs.RegisterTaskDefinitionInput{} - err = json.Unmarshal(obj, ©Clone) - if err != nil { - panic(err) - } + panicUnmarshal(obj, ©Clone) for name, secretMap := range secretMaps { for i, containerDefinition := range copyClone.ContainerDefinitions { if *containerDefinition.Name == name { @@ -166,14 +143,18 @@ func alterSecret(copy ecs.ContainerDefinition, secretMap map[string]string) ecs. return copy } -func copyTaskDef(api ecs.ECS, taskdef string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string) (string, error) { +func copyTaskDef(api ecs.ECS, taskdef string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, taskRole string) (string, error) { output, err := api.DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{TaskDefinition: aws.String(taskdef)}) if err != nil { return "", err } asRegisterTaskDefinitionInput := copyTd(*output.TaskDefinition, output.Tags) - tdCopy := alterLogConfigurations(alterSecrets(alterEnvironments(alterImages(asRegisterTaskDefinitionInput, imageMap), envMaps), secretMaps), logopts, logsecrets) + tdCopy := alterImages(asRegisterTaskDefinitionInput, imageMap) + tdCopy = alterEnvironments(tdCopy, envMaps) + tdCopy = alterSecrets(tdCopy, secretMaps) + tdCopy = alterLogConfigurations(tdCopy, logopts, logsecrets) + tdCopy = alterTaskRole(tdCopy, taskRole) if reflect.DeepEqual(asRegisterTaskDefinitionInput, tdCopy) { return *output.TaskDefinition.TaskDefinitionArn, nil @@ -186,7 +167,7 @@ func copyTaskDef(api ecs.ECS, taskdef string, imageMap map[string]string, envMap return *arn, nil } -func alterService(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, desiredCount *int64, taskdef string) (ecs.Service, ecs.Service, error) { +func alterService(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, taskRole string, desiredCount *int64, taskdef string) (ecs.Service, ecs.Service, error) { output, err := api.DescribeServices(&ecs.DescribeServicesInput{Cluster: aws.String(cluster), Services: []*string{aws.String(service)}}) if err != nil { return ecs.Service{}, ecs.Service{}, err @@ -196,7 +177,7 @@ func alterService(api ecs.ECS, cluster, service string, imageMap map[string]stri if taskdef != "" { srcTaskDef = &taskdef } - newTd, err := copyTaskDef(api, *srcTaskDef, imageMap, envMaps, secretMaps, logopts, logsecrets) + newTd, err := copyTaskDef(api, *srcTaskDef, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole) if err != nil { return *svc, ecs.Service{}, err } @@ -264,8 +245,8 @@ func validateDeployment(api ecs.ECS, ecsService ecs.Service, bo backoff.BackOff) return errNoPrimaryDeployment } -func alterServiceValidateDeployment(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, desiredCount *int64, taskdef string, bo backoff.BackOff) (ecs.Service, error) { - oldsvc, newsvc, err := alterService(api, cluster, service, imageMap, envMaps, secretMaps, logopts, logsecrets, desiredCount, taskdef) +func alterServiceValidateDeployment(api ecs.ECS, cluster, service string, imageMap map[string]string, envMaps map[string]map[string]string, secretMaps map[string]map[string]string, logopts map[string]map[string]map[string]string, logsecrets map[string]map[string]map[string]string, taskRole string, desiredCount *int64, taskdef string, bo backoff.BackOff) (ecs.Service, error) { + oldsvc, newsvc, err := alterService(api, cluster, service, imageMap, envMaps, secretMaps, logopts, logsecrets, taskRole, desiredCount, taskdef) if err != nil { return oldsvc, err } @@ -291,6 +272,7 @@ type ECSServiceUpdate struct { Secrets map[string]map[string]string // Map of container names environment variable name and valueFrom LogDriverOptions map[string]map[string]map[string]string // Map of container names log driver name log driver option and value LogDriverSecrets map[string]map[string]map[string]string // Map of container names log driver name log driver secret and valueFrom + TaskRole string // Task IAM Role if TaskRoleKnockoutValue used, it is cleared DesiredCount *int64 // If nil the service desired count is not altered BackOff backoff.BackOff // BackOff strategy to use when validating the update Taskdef string // If non empty used as base task definition instead of the current task definition @@ -298,5 +280,5 @@ type ECSServiceUpdate struct { // Apply the ECS Service Update func (e *ECSServiceUpdate) Apply() error { - return alterServiceOrValidatedRollBack(e.API, e.Cluster, e.Service, e.Image, e.Environment, e.Secrets, e.LogDriverOptions, e.LogDriverSecrets, e.DesiredCount, e.Taskdef, e.BackOff) + return alterServiceOrValidatedRollBack(e.API, e.Cluster, e.Service, e.Image, e.Environment, e.Secrets, e.LogDriverOptions, e.LogDriverSecrets, e.TaskRole, e.DesiredCount, e.Taskdef, e.BackOff) } diff --git a/ecs_logconfig.go b/ecs_logconfig.go index 8e4e20e..a86e870 100644 --- a/ecs_logconfig.go +++ b/ecs_logconfig.go @@ -1,7 +1,6 @@ package awsecs import ( - "encoding/json" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" ) @@ -112,15 +111,9 @@ func alterLogConfigurationLogDriverSecrets(copy ecs.LogConfiguration, overrides } func alterLogConfigurations(copy ecs.RegisterTaskDefinitionInput, containersOptions map[string]map[string]map[string]string, containersSecrets map[string]map[string]map[string]string) ecs.RegisterTaskDefinitionInput { - obj, err := json.Marshal(copy) - if err != nil { - panic(err) - } + obj := panicMarshal(copy) copyClone := ecs.RegisterTaskDefinitionInput{} - err = json.Unmarshal(obj, ©Clone) - if err != nil { - panic(err) - } + panicUnmarshal(obj, ©Clone) for _, containerDefinition := range copyClone.ContainerDefinitions { for containerName, containerOptions := range containersOptions { if containerDefinition.Name != nil && *containerDefinition.Name == containerName { diff --git a/ecs_roleconfig.go b/ecs_roleconfig.go new file mode 100644 index 0000000..e69fab9 --- /dev/null +++ b/ecs_roleconfig.go @@ -0,0 +1,21 @@ +package awsecs + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ecs" +) + +const TaskRoleKnockoutValue = "None" + +func alterTaskRole(copy ecs.RegisterTaskDefinitionInput, taskRoleArn string) ecs.RegisterTaskDefinitionInput { + obj := panicMarshal(copy) + copyClone := ecs.RegisterTaskDefinitionInput{} + panicUnmarshal(obj, ©Clone) + if taskRoleArn != "" { + copyClone.TaskRoleArn = aws.String(taskRoleArn) + } + if taskRoleArn == TaskRoleKnockoutValue { + copyClone.TaskRoleArn = nil + } + return copyClone +} diff --git a/ecs_roleconfig_test.go b/ecs_roleconfig_test.go new file mode 100644 index 0000000..ca022bf --- /dev/null +++ b/ecs_roleconfig_test.go @@ -0,0 +1,59 @@ +package awsecs + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ecs" + "reflect" + "testing" +) + +func TestAlterTaskRole(t *testing.T) { + type args struct { + copy ecs.RegisterTaskDefinitionInput + taskRoleArn string + } + tests := []struct { + name string + args args + want ecs.RegisterTaskDefinitionInput + }{ + { + name: "None test", + args: args{ + ecs.RegisterTaskDefinitionInput{ + TaskRoleArn: aws.String("something")}, + "None", + }, + want: ecs.RegisterTaskDefinitionInput{}, + }, + { + name: "Set value test", + args: args{ + ecs.RegisterTaskDefinitionInput{}, + "taskRoleArn", + }, + want: ecs.RegisterTaskDefinitionInput{ + TaskRoleArn: aws.String("taskRoleArn"), + }, + }, + { + name: "Keep value test", + args: args{ + ecs.RegisterTaskDefinitionInput{ + TaskRoleArn: aws.String("keepTaskRoleArn"), + }, + "", + }, + want: ecs.RegisterTaskDefinitionInput{ + TaskRoleArn: aws.String("keepTaskRoleArn"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := alterTaskRole(tt.args.copy, tt.args.taskRoleArn); !reflect.DeepEqual(got, tt.want) { + t.Errorf("alterTaskRole() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go.mod b/go.mod index 2034c16..27f8d77 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,11 @@ module github.com/Autodesk/go-awsecs -go 1.13 +go 1.15 require ( - github.com/aws/aws-sdk-go v1.20.21 + github.com/aws/aws-sdk-go v1.35.7 github.com/cenkalti/backoff v0.0.0-00010101000000-000000000000 - github.com/sergi/go-diff v1.0.0 - github.com/stretchr/testify v1.5.1 // indirect - golang.org/x/net v0.0.0-20200219183655-46282727080f // indirect + github.com/sergi/go-diff v1.1.0 ) -replace github.com/cenkalti/backoff => github.com/cenkalti/backoff/v3 v3.1.1 +replace github.com/cenkalti/backoff => github.com/cenkalti/backoff/v4 v4.1.0 diff --git a/go.sum b/go.sum index 27d7bca..c3c8246 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,35 @@ github.com/aws/aws-sdk-go v1.20.21 h1:22vHWL9rur+SRTYPHAXlxJMFIA9OSYsYDIAHFDhQ7Z0= github.com/aws/aws-sdk-go v1.20.21/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.35.7 h1:FHMhVhyc/9jljgFAcGkQDYjpC9btM0B8VfkLBfctdNE= +github.com/aws/aws-sdk-go v1.35.7/go.mod h1:tlPOdRjfxPBpNIwqDj61rmsnA85v9jc0Ps9+muhnW+k= github.com/cenkalti/backoff/v3 v3.1.1 h1:UBHElAnr3ODEbpqPzX8g5sBcASjoLFtt3L/xwJ01L6E= github.com/cenkalti/backoff/v3 v3.1.1/go.mod h1:cIeZDE3IrqwwJl6VUwCN6trj1oXrTS4rc0ij+ULvLYs= +github.com/cenkalti/backoff/v4 v4.1.0 h1:c8LkOFQTzuO0WBM/ae5HdGQuZPfPxp7lqBRwQRm4fSc= +github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200219183655-46282727080f h1:dB42wwhNuwPvh8f+5zZWNcU+F2Xs/B9wXXwvUCOH7r8= golang.org/x/net v0.0.0-20200219183655-46282727080f/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -21,5 +37,8 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/json.go b/json.go new file mode 100644 index 0000000..20024ad --- /dev/null +++ b/json.go @@ -0,0 +1,18 @@ +package awsecs + +import "encoding/json" + +func panicMarshal(v interface{}) (out []byte) { + out, err := json.Marshal(v) + if err != nil { + panic(err) + } + return +} + +func panicUnmarshal(data []byte, v interface{}) { + err := json.Unmarshal(data, v) + if err != nil { + panic(err) + } +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 0000000..d48f6d1 --- /dev/null +++ b/json_test.go @@ -0,0 +1,93 @@ +package awsecs + +import ( + "fmt" + "reflect" + "testing" +) + +func TestPanicMarshal(t *testing.T) { + type args struct { + v interface{} + } + funcMap := map[string]func(){} + funcMap["self"] = func() {} + tests := []struct { + name string + args args + wantOut []byte + }{ + { + name: "marshal", + args: args{ + map[string]string{ + "foo": "bar", + }, + }, + wantOut: []byte(`{"foo":"bar"}`), + }, + { + name: "panic marshal", + args: args{funcMap}, + wantOut: []byte("ignored"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if tt.name == "panic marshal" { + recoverTxt := fmt.Sprint(recover()) + if recoverTxt != "json: unsupported type: func()" { + t.Error(recoverTxt) + } + } + }() + if gotOut := panicMarshal(tt.args.v); !reflect.DeepEqual(gotOut, tt.wantOut) { + t.Errorf("panicMarshal() = %v, want %v", gotOut, tt.wantOut) + } + }) + } +} + +func TestPanicUnmarshal(t *testing.T) { + type args struct { + data []byte + v interface{} + } + simpleMap := map[string]string{} + tests := []struct { + name string + args args + }{ + { + name: "unmarshal", + args: args{ + []byte(`{"foo":"bar"}`), + &simpleMap, + }, + }, + { + name: "panic unmarshal", + args: args{ + []byte("BAD JSON"), + simpleMap, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if tt.name == "panic unmarshal" { + recoverTxt := fmt.Sprint(recover()) + if recoverTxt != "invalid character 'B' looking for beginning of value" { + t.Error(recoverTxt) + } + } + }() + panicUnmarshal(tt.args.data, &tt.args.v) + }) + } + if simpleMap["foo"] != "bar" { + t.Error(simpleMap["foo"]) + } +}