diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go index d5ad4a5c1858..494314f23582 100644 --- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go @@ -923,6 +923,5 @@ func (al *addressList) hasNext() bool { // fields that are meaningful to the SubConn. func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool { return a.Addr == b.Addr && a.ServerName == b.ServerName && - a.Attributes.Equal(b.Attributes) && - a.Metadata == b.Metadata + a.Attributes.Equal(b.Attributes) } diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go index a876fcb02f79..160783fc27d3 100644 --- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go +++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go @@ -41,6 +41,7 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/testutils/pickfirst" "google.golang.org/grpc/internal/testutils/stats" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" @@ -1424,6 +1425,85 @@ func (s) TestPickFirstLeaf_HealthUpdates(t *testing.T) { testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure) } +// Tests the case where an address update received by the pick_first LB policy +// differs in metadata which should be ignored by the LB policy. In this case, +// the test verifies that new connections are not created when the address +// update only changes the metadata. +func (s) TestPickFirstLeaf_AddressUpdateWithMetadata(t *testing.T) { + dialer := testutils.NewBlockingDialer() + dopts := []grpc.DialOption{ + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)), + grpc.WithContextDialer(dialer.DialContext), + } + cc, r, backends := setupPickFirstLeaf(t, 2, dopts...) + + // Add a metadata to the addresses before pushing them to the pick_first LB + // policy through the manual resolver. + addrs := backends.resolverAddrs() + for i := range addrs { + addrs[i].Metadata = &metadata.MD{ + "test-metadata-1": []string{fmt.Sprintf("%d", i)}, + } + } + r.UpdateState(resolver.State{Addresses: addrs}) + + // Ensure that RPCs succeed to the expected backend. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + // Create holds for each backend. This will be used to verify the connection + // is not re-established. + holds := backends.holds(dialer) + + // Add metadata to the addresses before pushing them to the pick_first LB + // policy through the manual resolver. Leave the order of the addresses + // unchanged. + for i := range addrs { + addrs[i].Metadata = &metadata.MD{ + "test-metadata-2": []string{fmt.Sprintf("%d", i)}, + } + } + r.UpdateState(resolver.State{Addresses: addrs}) + + // Ensure that no new connection is established. + for i := range holds { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if holds[i].Wait(sCtx) { + t.Fatalf("Unexpected connection attempt to backend: %s", addrs[i]) + } + } + + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil { + t.Fatal(err) + } + + // Add metadata to the addresses before pushing them to the pick_first LB + // policy through the manual resolver. Reverse of the order of addresses. + for i := range addrs { + addrs[i].Metadata = &metadata.MD{ + "test-metadata-3": []string{fmt.Sprintf("%d", i)}, + } + } + addrs[0], addrs[1] = addrs[1], addrs[0] + r.UpdateState(resolver.State{Addresses: addrs}) + + // Ensure that no new connection is established. + for i := range holds { + sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout) + defer sCancel() + if holds[i].Wait(sCtx) { + t.Fatalf("Unexpected connection attempt to backend: %s", addrs[i]) + } + } + if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil { + t.Fatal(err) + } +} + // healthListenerCapturingCCWrapper is used to capture the health listener so // that health updates can be mocked for testing. type healthListenerCapturingCCWrapper struct {