diff --git a/pkg/iac/scanners/terraform/module_test.go b/pkg/iac/scanners/terraform/module_test.go index d0d289a9d562..ec4291fd59e4 100644 --- a/pkg/iac/scanners/terraform/module_test.go +++ b/pkg/iac/scanners/terraform/module_test.go @@ -567,7 +567,7 @@ resource "something" "else" { for_each = toset(["true"]) content { - ok = each.value + ok = blah.value } } } diff --git a/pkg/iac/scanners/terraform/parser/evaluator.go b/pkg/iac/scanners/terraform/parser/evaluator.go index 809498f35c63..8e2e737b5d9b 100644 --- a/pkg/iac/scanners/terraform/parser/evaluator.go +++ b/pkg/iac/scanners/terraform/parser/evaluator.go @@ -260,36 +260,17 @@ func (e *evaluator) evaluateSteps() { } func (e *evaluator) expandBlocks(blocks terraform.Blocks) terraform.Blocks { - return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks), false)...) + return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks))...) } func (e *evaluator) expandDynamicBlocks(blocks ...*terraform.Block) terraform.Blocks { for _, b := range blocks { - e.expandDynamicBlock(b) - } - return blocks -} - -func (e *evaluator) expandDynamicBlock(b *terraform.Block) { - for _, sub := range b.AllBlocks() { - e.expandDynamicBlock(sub) - } - for _, sub := range b.AllBlocks().OfType("dynamic") { - if sub.IsExpanded() { - continue - } - blockName := sub.TypeLabel() - expanded := e.expandBlockForEaches(terraform.Blocks{sub}, true) - for _, ex := range expanded { - if content := ex.GetBlock("content"); content.IsNotNil() { - _ = e.expandDynamicBlocks(content) - b.InjectBlock(content, blockName) - } - } - if len(expanded) > 0 { - sub.MarkExpanded() + if err := b.ExpandBlock(); err != nil { + e.logger.Error(`Failed to expand dynamic block.`, + log.String("block", b.FullName()), log.Err(err)) } } + return blocks } func isBlockSupportsForEachMetaArgument(block *terraform.Block) bool { @@ -297,11 +278,10 @@ func isBlockSupportsForEachMetaArgument(block *terraform.Block) bool { "module", "resource", "data", - "dynamic", }, block.Type()) } -func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks, isDynamic bool) terraform.Blocks { +func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks) terraform.Blocks { var forEachFiltered terraform.Blocks @@ -348,7 +328,7 @@ func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks, isDynamic bool // is the value of the collection. The exception is the use of for-each inside a dynamic block, // because in this case the collection element may not be a primitive value. if (forEachVal.Type().IsCollectionType() || forEachVal.Type().IsTupleType()) && - !forEachVal.Type().IsMapType() && !isDynamic { + !forEachVal.Type().IsMapType() { stringVal, err := convert.Convert(val, cty.String) if err != nil { e.logger.Error( @@ -374,22 +354,7 @@ func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks, isDynamic bool ctx.Set(eachObj, "each") ctx.Set(eachObj, block.TypeLabel()) - - if isDynamic { - if iterAttr := block.GetAttribute("iterator"); iterAttr.IsNotNil() { - refs := iterAttr.AllReferences() - if len(refs) == 1 { - ctx.Set(idx, refs[0].TypeLabel(), "key") - ctx.Set(val, refs[0].TypeLabel(), "value") - } else { - e.logger.Debug("Ignoring iterator attribute in dynamic block, expected one reference", - log.Int("refs", len(refs))) - } - } - } - forEachFiltered = append(forEachFiltered, clone) - clones[idx.AsString()] = clone.Values() }) diff --git a/pkg/iac/scanners/terraform/parser/parser_test.go b/pkg/iac/scanners/terraform/parser/parser_test.go index e3bd817748f6..540338afb57f 100644 --- a/pkg/iac/scanners/terraform/parser/parser_test.go +++ b/pkg/iac/scanners/terraform/parser/parser_test.go @@ -1367,139 +1367,284 @@ func TestCountMetaArgumentInModule(t *testing.T) { } func TestDynamicBlocks(t *testing.T) { - t.Run("arg is list of int", func(t *testing.T) { - modules := parse(t, map[string]string{ - "main.tf": ` -resource "aws_security_group" "sg-webserver" { - vpc_id = "1111" - dynamic "ingress" { + tests := []struct { + name string + src string + expected []any + }{ + { + name: "for-each use tuple of int", + src: `resource "test_resource" "test" { + dynamic "foo" { for_each = [80, 443] content { - from_port = ingress.value - to_port = ingress.value - protocol = "tcp" - cidr_blocks = ["0.0.0.0/0"] + bar = foo.value } } -} -`, - }) - require.Len(t, modules, 1) - - secGroups := modules.GetResourcesByType("aws_security_group") - assert.Len(t, secGroups, 1) - ingressBlocks := secGroups[0].GetBlocks("ingress") - assert.Len(t, ingressBlocks, 2) - - var inboundPorts []int - for _, ingress := range ingressBlocks { - fromPort := ingress.GetAttribute("from_port").AsIntValueOrDefault(-1, ingress).Value() - inboundPorts = append(inboundPorts, fromPort) - } - - assert.True(t, compareSets([]int{80, 443}, inboundPorts)) - }) - - t.Run("empty for-each", func(t *testing.T) { - modules := parse(t, map[string]string{ - "main.tf": ` -resource "aws_lambda_function" "analyzer" { - dynamic "vpc_config" { +}`, + expected: []any{float64(80), float64(443)}, + }, + { + name: "for-each use list of int", + src: `resource "test_resource" "test" { + dynamic "foo" { + for_each = tolist([80, 443]) + content { + bar = foo.value + } + } +}`, + expected: []any{float64(80), float64(443)}, + }, + { + name: "for-each use set of int", + src: `resource "test_resource" "test" { + dynamic "foo" { + for_each = toset([80, 443]) + content { + bar = foo.value + } + } +}`, + expected: []any{float64(80), float64(443)}, + }, + { + name: "for-each use list of bool", + src: `resource "test_resource" "test" { + dynamic "foo" { + for_each = tolist([true]) + content { + bar = foo.value + } + } +}`, + expected: []any{true}, + }, + { + name: "empty for-each", + src: `resource "test_resource" "test" { + dynamic "foo" { for_each = [] content {} } +}`, + expected: []any{}, + }, + { + name: "for-each use tuple of objects", + src: `variable "test_var" { + type = list(object({ enabled = bool })) + default = [{ enabled = true }] } -`, - }) - require.Len(t, modules, 1) - functions := modules.GetResourcesByType("aws_lambda_function") - assert.Len(t, functions, 1) - vpcConfigs := functions[0].GetBlocks("vpc_config") - assert.Empty(t, vpcConfigs) - }) +resource "test_resource" "test" { + dynamic "foo" { + for_each = var.test_var - t.Run("arg is list of bool", func(t *testing.T) { - modules := parse(t, map[string]string{ - "main.tf": ` -resource "aws_lambda_function" "analyzer" { - dynamic "vpc_config" { - for_each = [true] - content {} + content { + bar = foo.value.enabled + } + } +}`, + expected: []any{true}, + }, + { + name: "attribute ref to object key", + src: `variable "some_var" { + type = map( + object({ + tag = string + }) + ) + default = { + ssh = { "tag" = "login" } + http = { "tag" = "proxy" } + https = { "tag" = "proxy" } } } -`, - }) - require.Len(t, modules, 1) - - functions := modules.GetResourcesByType("aws_lambda_function") - assert.Len(t, functions, 1) - vpcConfigs := functions[0].GetBlocks("vpc_config") - assert.Len(t, vpcConfigs, 1) - }) - t.Run("arg is list of objects", func(t *testing.T) { - modules := parse(t, map[string]string{ - "main.tf": `locals { - cluster_network_policy = [{ - enabled = true - }] +resource "test_resource" "test" { + dynamic "foo" { + for_each = { for name, values in var.some_var : name => values } + content { + bar = foo.key + } + } +}`, + expected: []any{"ssh", "http", "https"}, + }, + { + name: "attribute ref to object value", + src: `variable "some_var" { + type = map( + object({ + tag = string + }) + ) + default = { + ssh = { "tag" = "login" } + http = { "tag" = "proxy" } + https = { "tag" = "proxy" } + } } -resource "google_container_cluster" "primary" { - name = "test" +resource "test_resource" "test" { + dynamic "foo" { + for_each = { for name, values in var.some_var : name => values } + content { + bar = foo.value.tag + } + } +}`, + expected: []any{"login", "proxy", "proxy"}, + }, + { + name: "attribute ref to map key", + src: `variable "some_var" { + type = map + default = { + ssh = { "tag" = "login" } + http = { "tag" = "proxy" } + https = { "tag" = "proxy" } + } +} - dynamic "network_policy" { - for_each = local.cluster_network_policy +resource "test_resource" "test" { + dynamic "foo" { + for_each = var.some_var + content { + bar = foo.key + } + } +}`, + expected: []any{"ssh", "http", "https"}, + }, + { + name: "attribute ref to map value", + src: `variable "some_var" { + type = map + default = { + ssh = { "tag" = "login" } + http = { "tag" = "proxy" } + https = { "tag" = "proxy" } + } +} +resource "test_resource" "test" { + dynamic "foo" { + for_each = var.some_var content { - enabled = network_policy.value.enabled + bar = foo.value.tag } } }`, - }) - require.Len(t, modules, 1) + expected: []any{"login", "proxy", "proxy"}, + }, + { + name: "dynamic block with iterator", + src: `resource "test_resource" "test" { + dynamic "foo" { + for_each = ["foo", "bar"] + iterator = some_iterator + content { + bar = some_iterator.value + } + } +}`, + expected: []any{"foo", "bar"}, + }, + { + name: "iterator and parent block with same name", + src: `resource "test_resource" "test" { + dynamic "foo" { + for_each = ["foo", "bar"] + iterator = foo + content { + bar = foo.value + } + } +}`, + expected: []any{"foo", "bar"}, + }, + { + name: "for-each use null value", + src: `resource "test_resource" "test" { + dynamic "foo" { + for_each = null + content { + bar = foo.value + } + } +}`, + expected: []any{}, + }, + { + name: "no for-each attribute", + src: `resource "test_resource" "test" { + dynamic "foo" { + content { + bar = foo.value + } + } +}`, + expected: []any{}, + }, + } - clusters := modules.GetResourcesByType("google_container_cluster") - assert.Len(t, clusters, 1) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": tt.src, + }) + require.Len(t, modules, 1) - networkPolicies := clusters[0].GetBlocks("network_policy") - assert.Len(t, networkPolicies, 1) + resource := modules.GetResourcesByType("test_resource") + require.Len(t, resource, 1) + blocks := resource[0].GetBlocks("foo") - enabled := networkPolicies[0].GetAttribute("enabled") - assert.True(t, enabled.Value().True()) - }) + var vals []any + for _, attr := range blocks { + vals = append(vals, attr.GetAttribute("bar").GetRawValue()) + } - t.Run("nested dynamic", func(t *testing.T) { - modules := parse(t, map[string]string{ - "main.tf": ` -resource "test_block" "this" { - name = "name" - location = "loc" - dynamic "env" { - for_each = ["1", "2"] - content { - dynamic "value_source" { - for_each = [true, true] - content {} - } + assert.ElementsMatch(t, tt.expected, vals) + }) } +} + +func TestNestedDynamicBlock(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": `resource "test_resource" "test" { + dynamic "foo" { + for_each = ["1", "1"] + content { + dynamic "bar" { + for_each = [true, true] + content { + baz = foo.value + qux = bar.value + } + } + } } }`, - }) - require.Len(t, modules, 1) + }) + require.Len(t, modules, 1) - testResources := modules.GetResourcesByType("test_block") - assert.Len(t, testResources, 1) - envs := testResources[0].GetBlocks("env") - assert.Len(t, envs, 2) + testResources := modules.GetResourcesByType("test_resource") + assert.Len(t, testResources, 1) + blocks := testResources[0].GetBlocks("foo") + assert.Len(t, blocks, 2) - var sources []*terraform.Block - for _, env := range envs { - sources = append(sources, env.GetBlocks("value_source")...) + var nested []*terraform.Block + for _, block := range blocks { + nested = append(nested, block.GetBlocks("bar")...) + for _, b := range nested { + assert.Equal(t, "1", b.GetAttribute("baz").GetRawValue()) + assert.Equal(t, true, b.GetAttribute("qux").GetRawValue()) } - assert.Len(t, sources, 4) - }) + } + assert.Len(t, nested, 4) } func parse(t *testing.T, files map[string]string) terraform.Modules { @@ -1513,21 +1658,6 @@ func parse(t *testing.T, files map[string]string) terraform.Modules { return modules } -func compareSets(a, b []int) bool { - m := make(map[int]bool) - for _, el := range a { - m[el] = true - } - - for _, el := range b { - if !m[el] { - return false - } - } - - return true -} - func TestModuleRefersToOutputOfAnotherModule(t *testing.T) { files := map[string]string{ "main.tf": ` @@ -1775,42 +1905,6 @@ variable "foo" {} assert.Equal(t, "bar", blocks[0].GetAttribute("foo").Value().AsString()) } -func TestDynamicWithIterator(t *testing.T) { - fsys := fstest.MapFS{ - "main.tf": &fstest.MapFile{ - Data: []byte(`resource "aws_s3_bucket" "this" { - dynamic versioning { - for_each = [true] - iterator = ver - - content { - enabled = ver.value - } - } -}`), - }, - } - - parser := New( - fsys, "", - OptionStopOnHCLError(true), - OptionWithDownloads(false), - ) - require.NoError(t, parser.ParseFS(context.TODO(), ".")) - - modules, _, err := parser.EvaluateAll(context.TODO()) - require.NoError(t, err) - - assert.Len(t, modules, 1) - - buckets := modules.GetResourcesByType("aws_s3_bucket") - assert.Len(t, buckets, 1) - - attr, _ := buckets[0].GetNestedAttribute("versioning.enabled") - - assert.True(t, attr.Value().True()) -} - func Test_AWSRegionNameDefined(t *testing.T) { fs := testutil.CreateFS(t, map[string]string{ diff --git a/pkg/iac/terraform/attribute.go b/pkg/iac/terraform/attribute.go index 39161a72a379..9c5536357283 100644 --- a/pkg/iac/terraform/attribute.go +++ b/pkg/iac/terraform/attribute.go @@ -729,15 +729,16 @@ func (a *Attribute) IsTrue() bool { if a == nil { return false } - switch a.Value().Type() { + val := a.Value() + switch val.Type() { case cty.Bool: - return a.Value().True() + return val.True() case cty.String: - val := a.Value().AsString() + val := val.AsString() val = strings.Trim(val, "\"") return strings.EqualFold(val, "true") case cty.Number: - val := a.Value().AsBigFloat() + val := val.AsBigFloat() f, _ := val.Float64() return f > 0 } diff --git a/pkg/iac/terraform/block.go b/pkg/iac/terraform/block.go index dbf7352959d0..dd6d8446c124 100644 --- a/pkg/iac/terraform/block.go +++ b/pkg/iac/terraform/block.go @@ -1,6 +1,7 @@ package terraform import ( + "errors" "fmt" "io/fs" "strconv" @@ -137,16 +138,15 @@ func (b *Block) GetRawValue() any { return nil } -func (b *Block) InjectBlock(block *Block, name string) { - block.hclBlock.Labels = []string{} - block.hclBlock.Type = name +func (b *Block) injectBlock(block *Block) { for attrName, attr := range block.Attributes() { - b.context.Root().SetByDot(attr.Value(), fmt.Sprintf("%s.%s.%s", b.reference.String(), name, attrName)) + path := fmt.Sprintf("%s.%s.%s", b.reference.String(), block.hclBlock.Type, attrName) + b.context.Root().SetByDot(attr.Value(), path) } b.childBlocks = append(b.childBlocks, block) } -func (b *Block) MarkExpanded() { +func (b *Block) markExpanded() { b.expanded = true } @@ -154,17 +154,26 @@ func (b *Block) IsExpanded() bool { return b.expanded } -func (b *Block) Clone(index cty.Value) *Block { - var childCtx *context.Context - if b.context != nil { - childCtx = b.context.NewChild() - } else { - childCtx = context.NewContext(&hcl.EvalContext{}, nil) +func (b *Block) inherit(ctx *context.Context, index ...cty.Value) *Block { + return NewBlock(b.copyBlock(), ctx, b.moduleBlock, b.parentBlock, b.moduleSource, b.moduleFS, index...) +} + +func (b *Block) copyBlock() *hcl.Block { + hclBlock := *b.hclBlock + return &hclBlock +} + +func (b *Block) childContext() *context.Context { + if b.context == nil { + return context.NewContext(&hcl.EvalContext{}, nil) } + return b.context.NewChild() +} - cloneHCL := *b.hclBlock +func (b *Block) Clone(index cty.Value) *Block { + childCtx := b.childContext() + clone := b.inherit(childCtx, index) - clone := NewBlock(&cloneHCL, childCtx, b.moduleBlock, b.parentBlock, b.moduleSource, b.moduleFS, index) if len(clone.hclBlock.Labels) > 0 { position := len(clone.hclBlock.Labels) - 1 labels := make([]string, len(clone.hclBlock.Labels)) @@ -188,7 +197,7 @@ func (b *Block) Clone(index cty.Value) *Block { } indexVal, _ := gocty.ToCtyValue(index, cty.Number) clone.context.SetByDot(indexVal, "count.index") - clone.MarkExpanded() + clone.markExpanded() b.cloneIndex++ return clone } @@ -446,6 +455,17 @@ func (b *Block) LocalName() string { return b.reference.String() } +func (b *Block) FullLocalName() string { + if b.parentBlock != nil { + return fmt.Sprintf( + "%s.%s", + b.parentBlock.FullLocalName(), + b.LocalName(), + ) + } + return b.LocalName() +} + func (b *Block) FullName() string { if b.moduleBlock != nil { @@ -576,3 +596,118 @@ func (b *Block) IsNil() bool { func (b *Block) IsNotNil() bool { return !b.IsNil() } + +func (b *Block) ExpandBlock() error { + var ( + expanded []*Block + errs []error + ) + + for _, child := range b.childBlocks { + if child.Type() == "dynamic" { + blocks, err := child.expandDynamic() + if err != nil { + errs = append(errs, err) + continue + } + expanded = append(expanded, blocks...) + } + } + + for _, block := range expanded { + b.injectBlock(block) + } + + return errors.Join(errs...) +} + +func (b *Block) expandDynamic() ([]*Block, error) { + if b.IsExpanded() || b.Type() != "dynamic" { + return nil, nil + } + + realBlockType := b.TypeLabel() + if realBlockType == "" { + return nil, errors.New("dynamic block must have 1 label") + } + + forEachVal, err := b.validateForEach() + if err != nil { + return nil, fmt.Errorf("invalid for-each in %s block: %w", b.FullLocalName(), err) + } + + var ( + expanded []*Block + errs []error + ) + + forEachVal.ForEachElement(func(key, val cty.Value) (stop bool) { + if val.IsNull() { + return + } + + iteratorName, err := b.iteratorName(realBlockType) + if err != nil { + errs = append(errs, err) + return + } + + forEachCtx := b.childContext() + obj := cty.ObjectVal(map[string]cty.Value{ + "key": key, + "value": val, + }) + forEachCtx.Set(obj, iteratorName) + + if content := b.GetBlock("content"); content != nil { + inherited := content.inherit(forEachCtx) + inherited.hclBlock.Labels = []string{} + inherited.hclBlock.Type = realBlockType + if err := inherited.ExpandBlock(); err != nil { + errs = append(errs, err) + return + } + expanded = append(expanded, inherited) + } + return + }) + + if len(expanded) > 0 { + b.markExpanded() + } + + return expanded, errors.Join(errs...) +} + +func (b *Block) validateForEach() (cty.Value, error) { + forEachAttr := b.GetAttribute("for_each") + if forEachAttr == nil { + return cty.NilVal, errors.New("for_each attribute required") + } + + forEachVal := forEachAttr.Value() + + if !forEachVal.CanIterateElements() { + return cty.NilVal, fmt.Errorf("cannot use a %s value in for_each. An iterable collection is required", forEachVal.GoString()) + } + + return forEachVal, nil +} + +func (b *Block) iteratorName(blockType string) (string, error) { + iteratorAttr := b.GetAttribute("iterator") + if iteratorAttr == nil { + return blockType, nil + } + + traversal, diags := hcl.AbsTraversalForExpr(iteratorAttr.hclAttribute.Expr) + if diags.HasErrors() { + return "", diags + } + + if len(traversal) != 1 { + return "", fmt.Errorf("dynamic iterator must be a single variable name") + } + + return traversal.RootName(), nil +}