diff --git a/main.tf b/main.tf index db43566..bdef9a3 100644 --- a/main.tf +++ b/main.tf @@ -5,7 +5,7 @@ locals { transit_gateway_enabled = local.enabled && var.transit_gateway_enabled transit_gateway_attachment_id = join("", aws_vpn_connection.default[*].transit_gateway_attachment_id) - vpn_gateway_id = join("", aws_vpn_gateway.default[*].id) + vpn_gateway_id = one(aws_vpn_gateway.default[*].id) customer_gateway_id = join("", aws_customer_gateway.default[*].id) vpn_connection_id = join("", aws_vpn_connection.default[*].id) } @@ -20,11 +20,12 @@ resource "aws_vpn_gateway" "default" { # https://www.terraform.io/docs/providers/aws/r/customer_gateway.html resource "aws_customer_gateway" "default" { - count = local.enabled && var.customer_gateway_ip_address != null ? 1 : 0 - bgp_asn = var.customer_gateway_bgp_asn - ip_address = var.customer_gateway_ip_address - type = "ipsec.1" - tags = module.this.tags + count = local.enabled && var.customer_gateway_ip_address != null ? 1 : 0 + device_name = module.this.id + bgp_asn = var.customer_gateway_bgp_asn + ip_address = var.customer_gateway_ip_address + type = "ipsec.1" + tags = module.this.tags } module "logs" { @@ -96,7 +97,7 @@ resource "aws_vpn_connection" "default" { # https://www.terraform.io/docs/providers/aws/r/vpn_gateway_route_propagation.html resource "aws_vpn_gateway_route_propagation" "default" { - count = local.enabled ? length(var.route_table_ids) : 0 + count = local.enabled && !var.transit_gateway_enabled ? length(var.route_table_ids) : 0 vpn_gateway_id = local.vpn_gateway_id route_table_id = element(var.route_table_ids, count.index) } diff --git a/test/src/examples_complete_test.go b/test/src/examples_complete_test.go index d36fd26..f529254 100644 --- a/test/src/examples_complete_test.go +++ b/test/src/examples_complete_test.go @@ -2,6 +2,7 @@ package test import ( "os" + "os/exec" "strings" "testing" @@ -20,11 +21,26 @@ func cleanup(t *testing.T, terraformOptions *terraform.Options, tempTestFolder s os.RemoveAll(tempTestFolder) } +func detectPlatform() string { + cmd := exec.Command("terraform", "--version") + out, _ := cmd.CombinedOutput() + platform := "" + if strings.Contains(string(out), "Terraform") { + platform = "tf" + } else if strings.Contains(string(out), "OpenTofu") { + platform = "tofu" + } else { + platform = "unknown" + } + return platform +} + // Test the Terraform module in examples/complete using Terratest. func TestExamplesComplete(t *testing.T) { t.Parallel() randID := strings.ToLower(random.UniqueId()) - attributes := []string{randID} + platform := detectPlatform() + attributes := []string{randID, platform} rootFolder := "../../" terraformFolderRelativeToRoot := "examples/complete" @@ -61,7 +77,8 @@ func TestExamplesComplete(t *testing.T) { func TestExamplesCompleteDisabled(t *testing.T) { t.Parallel() randID := strings.ToLower(random.UniqueId()) - attributes := []string{randID} + platform := detectPlatform() + attributes := []string{randID, platform} rootFolder := "../../" terraformFolderRelativeToRoot := "examples/complete"