From 237fa47f39aa438100ddadb72d4d7a4be8a967b0 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 26 Dec 2024 16:42:07 +0000 Subject: [PATCH 1/8] add DAC to live managed ID test --- .../testdata/managed-id-test/main.go | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/sdk/azidentity/testdata/managed-id-test/main.go b/sdk/azidentity/testdata/managed-id-test/main.go index 68356f8907bb..a6aed34fba02 100644 --- a/sdk/azidentity/testdata/managed-id-test/main.go +++ b/sdk/azidentity/testdata/managed-id-test/main.go @@ -9,9 +9,12 @@ import ( "log" "net/http" "os" + "regexp" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" ) @@ -37,15 +40,31 @@ var ( workloadID: os.Getenv("AZIDENTITY_USE_WORKLOAD_IDENTITY") != "", } + // jwtRegex is used to redact JWTs (e.g. access tokens) in log output sent to a test client, although + // that output should never contain tokens because it's sent only when a test fails i.e., the request + // handler couldn't obtain an access token + jwtRegex = regexp.MustCompile(`ey\S+\.\S+\.\S+`) + logOptions = policy.LogOptions{ + AllowedQueryParams: []string{"client_id", "mi_res_id", "msi_res_id", "object_id", "principal_id", "resource"}, + IncludeBody: true, + } + // logs collects log output from a test run to help debug failures. Note that its usage isn't + // concurrency-safe and that's okay because live managed identity tests targeting this server + // don't send concurrent requests. + logs strings.Builder missingConfig string ) func credential(id azidentity.ManagedIDKind) (azcore.TokenCredential, error) { + co := azcore.ClientOptions{Logging: logOptions} if config.workloadID { // the identity is determined by service account configuration - return azidentity.NewWorkloadIdentityCredential(nil) + return azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ClientOptions: co}) } - return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ID: id}) + return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ + ClientOptions: co, + ID: id, + }) } func listContainers(account string, cred azcore.TokenCredential) error { @@ -59,6 +78,7 @@ func listContainers(account string, cred azcore.TokenCredential) error { } func handler(w http.ResponseWriter, r *http.Request) { + logs.Reset() log.Print("received a request") if missingConfig != "" { w.WriteHeader(http.StatusInternalServerError) @@ -68,6 +88,11 @@ func handler(w http.ResponseWriter, r *http.Request) { cred, err := credential(nil) if err == nil { + name := "ManagedIdentityCredential" + if config.workloadID { + name = "WorkloadIdentityCredential" + } + logs.WriteString("\n*** testing " + name + "\n\n") err = listContainers(config.storageName, cred) } if err == nil && !config.workloadID { @@ -82,17 +107,38 @@ func handler(w http.ResponseWriter, r *http.Request) { } } + if err == nil { + // discard logs from the successful tests above + logs.Reset() + logs.WriteString("*** testing DefaultAzureCredential\n\n") + cred, err = azidentity.NewDefaultAzureCredential( + &azidentity.DefaultAzureCredentialOptions{ + ClientOptions: azcore.ClientOptions{Logging: logOptions}, + }, + ) + if err == nil { + err = listContainers(config.storageName, cred) + } + } + if err == nil { fmt.Fprint(w, "test passed") log.Print("test passed") } else { w.WriteHeader(http.StatusInternalServerError) - fmt.Fprint(w, err) - log.Print(err) + logs.WriteString("\n*** test failed with error: " + err.Error() + "\n") + fmt.Fprint(w, logs.String()) + log.Print(logs.String()) } } func main() { + azlog.SetListener(func(_ azlog.Event, msg string) { + msg = jwtRegex.ReplaceAllString(msg, "***") + logs.WriteString(msg + "\n\n") + }) + azlog.SetEvents(azidentity.EventAuthentication, azlog.EventRequest, azlog.EventResponse) + v := []string{} if config.storageName == "" { v = append(v, "AZIDENTITY_STORAGE_NAME") From d6349bc2abd3a995b56cdbafa7a73e3f5d63c37d Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 27 Dec 2024 15:37:03 +0000 Subject: [PATCH 2/8] always return 200 so wget always outputs response bodies --- sdk/azidentity/testdata/managed-id-test/main.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sdk/azidentity/testdata/managed-id-test/main.go b/sdk/azidentity/testdata/managed-id-test/main.go index a6aed34fba02..b14dde7c1ee9 100644 --- a/sdk/azidentity/testdata/managed-id-test/main.go +++ b/sdk/azidentity/testdata/managed-id-test/main.go @@ -81,7 +81,6 @@ func handler(w http.ResponseWriter, r *http.Request) { logs.Reset() log.Print("received a request") if missingConfig != "" { - w.WriteHeader(http.StatusInternalServerError) fmt.Fprint(w, "need a value for "+missingConfig) return } @@ -121,15 +120,13 @@ func handler(w http.ResponseWriter, r *http.Request) { } } - if err == nil { - fmt.Fprint(w, "test passed") - log.Print("test passed") - } else { - w.WriteHeader(http.StatusInternalServerError) + msg := "test passed" + if err != nil { logs.WriteString("\n*** test failed with error: " + err.Error() + "\n") - fmt.Fprint(w, logs.String()) - log.Print(logs.String()) + msg = logs.String() } + fmt.Fprint(w, msg) + log.Print(msg) } func main() { From 94d1dacd6160d9addb34a8eab5a69f3708c0e4d9 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Sat, 11 Jan 2025 01:18:42 +0000 Subject: [PATCH 3/8] Credential chains continue iterating after unexpected IMDS responses --- sdk/azidentity/CHANGELOG.md | 4 + .../default_azure_credential_test.go | 37 +++++---- sdk/azidentity/managed_identity_client.go | 33 ++++---- .../managed_identity_client_test.go | 79 +++++++++++++++++++ 4 files changed, 120 insertions(+), 33 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index b25fff498cbc..eb0d3cd27688 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -8,8 +8,12 @@ ### Bugs Fixed * User credential types inconsistently log access token scopes +* `DefaultAzureCredential` skips managed identity in Azure Container Instances ### Other Changes +* `ChainedTokenCredential` and `DefaultAzureCredential` continue to their next + credential after `ManagedIdentityCredential` receives an unexpected response + from IMDS, indicating the response is from something else such as a proxy ## 1.8.0 (2024-10-08) diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index f8dd3b6d806c..d69591f82024 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -335,27 +335,26 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) { require.True(t, probed) require.Equal(t, tokenValue, tk.Token) - t.Run("non-JSON response", func(t *testing.T) { + t.Run("Azure Container Instances", func(t *testing.T) { + // ensure GetToken returns an error if DefaultAzureCredential skips managed identity before := defaultAzTokenProvider defer func() { defaultAzTokenProvider = before }() - defaultAzTokenProvider = mockAzTokenProviderSuccess - for _, res := range [][]mock.ResponseOption{ - {mock.WithStatusCode(http.StatusNotFound)}, - {mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusBadRequest)}, - {mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)}, - } { - srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.SetResponse(res...) - cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ - ClientOptions: policy.ClientOptions{ - Transport: srv, - }, - }) - require.NoError(t, err) - _, err = cred.GetToken(ctx, testTRO) - require.NoError(t, err, "DefaultAzureCredential should continue after receiving a non-JSON response from IMDS") - } + defaultAzTokenProvider = mockAzTokenProviderFailure + + srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse(mock.WithBody([]byte("Required metadata header not specified or not correct")), mock.WithStatusCode(http.StatusBadRequest)) + srv.AppendResponse(mock.WithBody(accessTokenRespSuccess), mock.WithStatusCode(http.StatusOK)) + + cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ + ClientOptions: policy.ClientOptions{ + Transport: srv, + }, + }) + require.NoError(t, err) + tk, err := cred.GetToken(ctx, testTRO) + require.NoError(t, err, "DefaultAzureCredential should accept ACI's response to the probe request") + require.Equal(t, tokenValue, tk.Token) }) }) diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index 260c04355041..6490e0639627 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -215,6 +215,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi // no need to synchronize around this value because it's true only when DefaultAzureCredential constructed the client, // and in that case ChainedTokenCredential.GetToken synchronizes goroutines that would execute this block if c.probeIMDS { + // send a malformed request (no Metadata header) to IMDS to determine whether the endpoint is available cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout) defer cancel() cx = policy.WithRetryOptions(cx, policy.RetryOptions{MaxRetries: -1}) @@ -222,24 +223,14 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi if err != nil { return azcore.AccessToken{}, fmt.Errorf("failed to create IMDS probe request: %s", err) } - res, err := c.azClient.Pipeline().Do(req) - if err != nil { + if _, err = c.azClient.Pipeline().Do(req); err != nil { msg := err.Error() if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { msg = "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information" } return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg) } - // because IMDS always responds with JSON, assume a non-JSON response is from something else, such - // as a proxy, and return credentialUnavailableError so DefaultAzureCredential continues iterating - b, err := azruntime.Payload(res) - if err != nil { - return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("failed to read IMDS probe response: %s", err)) - } - if !json.Valid(b) { - return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, "unexpected response to IMDS probe") - } - // send normal token requests from now on because IMDS responded + // send normal token requests from now on because someothing responded c.probeIMDS = false } @@ -254,13 +245,21 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi } if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { - return c.createAccessToken(resp) + tk, err := c.createAccessToken(resp) + if err != nil && c.msiType == msiTypeIMDS { + // failure to unmarshal a 2xx implies the response is from something other than IMDS such as a proxy listening at + // the same address. Return a credentialUnavailableError so credential chains continue to their next credential + err = newCredentialUnavailableError(credNameManagedIdentity, err.Error()) + } + return tk, err } if c.msiType == msiTypeIMDS { switch resp.StatusCode { case http.StatusBadRequest: if id != nil { + // return authenticationFailedError, halting any encompassing credential chain, + // because the explicit user-assigned identity implies the developer expected this to work return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp) } msg := "failed to authenticate a system assigned identity" @@ -275,6 +274,12 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi if err == nil && strings.Contains(string(body), "unreachable") { return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body))) } + default: + // the response may be from something other than IMDS, for example a proxy returning + // 404. Return credentialUnavailableError so credential chains continue to their + // next credential, include the response in the error message to help debugging + err = newAuthenticationFailedError(credNameManagedIdentity, "", resp) + return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, err.Error()) } } @@ -290,7 +295,7 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string }{} if err := azruntime.UnmarshalAsJSON(res, &value); err != nil { - return azcore.AccessToken{}, fmt.Errorf("internal AccessToken: %v", err) + return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "Unexpected response content", res) } if value.ExpiresIn != "" { expiresIn, err := json.Number(value.ExpiresIn).Int64() diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go index 78d54c4dcb42..1c9f08393532 100644 --- a/sdk/azidentity/managed_identity_client_test.go +++ b/sdk/azidentity/managed_identity_client_test.go @@ -18,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" ) type userAgentValidatingPolicy struct { @@ -97,6 +98,11 @@ func TestManagedIdentityClient_IMDSErrors(t *testing.T) { code: http.StatusForbidden, body: "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable host.", }, + { + desc: "Azure Container Instances", + code: http.StatusBadRequest, + body: "Required metadata header not specified or not correct", + }, } { t.Run(fmt.Sprint(test.code), func(t *testing.T) { srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) @@ -122,3 +128,76 @@ func TestManagedIdentityClient_IMDSErrors(t *testing.T) { }) } } + +func TestManagedIdentityClient_IMDSProbeReturnsUnavailableError(t *testing.T) { + for _, test := range []struct { + desc string + res [][]mock.ResponseOption + }{ + { + "Azure Container Instance", + [][]mock.ResponseOption{ + { + mock.WithBody([]byte("Required metadata header not specified or not correct")), + mock.WithStatusCode(http.StatusBadRequest), + }, + {mock.WithBody([]byte("error")), mock.WithStatusCode(http.StatusBadRequest)}, + }, + }, + { + "404", + [][]mock.ResponseOption{ + {mock.WithStatusCode(http.StatusNotFound)}, + {mock.WithStatusCode(http.StatusNotFound)}, + }, + }, + { + "non-JSON token response", + [][]mock.ResponseOption{ + {mock.WithBody([]byte("Required metadata header not specified or not correct"))}, + {mock.WithBody([]byte("not json"))}, + }, + }, + { + "no token in response", + [][]mock.ResponseOption{ + {mock.WithBody([]byte("Required metadata header not specified or not correct"))}, + {mock.WithBody([]byte(`{"error": "no token here"}`)), mock.WithStatusCode(http.StatusBadRequest)}, + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + for _, res := range test.res { + srv.AppendResponse(res...) + } + c, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{ + ClientOptions: azcore.ClientOptions{ + // disabling retries simplifies the 404 test (404 is retriable for IMDS) + Retry: policy.RetryOptions{MaxRetries: -1}, + Transport: srv, + }, + dac: true, + }) + require.NoError(t, err) + _, err = c.authenticate(ctx, nil, testTRO.Scopes) + var cu credentialUnavailable + require.ErrorAs(t, err, &cu) + + srv.AppendResponse( + mock.WithBody(accessTokenRespSuccess), + mock.WithPredicate(func(r *http.Request) bool { + if _, ok := r.Header["Metadata"]; !ok { + t.Error("client shouldn't send another probe after receiving a response") + } + return true + }), + ) + srv.AppendResponse() + tk, err := c.authenticate(ctx, nil, testTRO.Scopes) + require.NoError(t, err) + require.Equal(t, tokenValue, tk.Token) + }) + } +} From 78614aee62da8124f89ce4024415ecc939fde4fa Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 13 Jan 2025 18:55:31 +0000 Subject: [PATCH 4/8] refactor tests --- .../default_azure_credential_test.go | 80 ++++++++++- .../managed_identity_client_test.go | 128 ------------------ 2 files changed, 78 insertions(+), 130 deletions(-) diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index d69591f82024..721bfd4753ad 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -10,7 +10,6 @@ import ( "context" "errors" "fmt" - "io" "net/http" "os" "path/filepath" @@ -322,7 +321,6 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) { probed = true require.Empty(t, hdr, "probe request shouldn't have Metadata header") return &http.Response{ - Body: io.NopCloser(strings.NewReader("{}")), StatusCode: http.StatusInternalServerError, } }, @@ -396,6 +394,84 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) { }) } +func TestDefaultAzureCredential_UnexpectedIMDSResponse(t *testing.T) { + before := defaultAzTokenProvider + defer func() { defaultAzTokenProvider = before }() + defaultAzTokenProvider = mockAzTokenProviderSuccess + + const dockerDesktopPrefix = "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable " + for _, test := range []struct { + desc string + res [][]mock.ResponseOption + }{ + { + "Docker Desktop", + [][]mock.ResponseOption{ + { + mock.WithBody([]byte(dockerDesktopPrefix + "host.")), + mock.WithStatusCode(http.StatusForbidden), + }, + { + mock.WithBody([]byte(dockerDesktopPrefix + "host.")), + mock.WithStatusCode(http.StatusForbidden), + }, + }, + }, + { + "Docker Desktop", + [][]mock.ResponseOption{ + { + mock.WithBody([]byte(dockerDesktopPrefix + "network.")), + mock.WithStatusCode(http.StatusForbidden), + }, + { + mock.WithBody([]byte(dockerDesktopPrefix + "network.")), + mock.WithStatusCode(http.StatusForbidden), + }, + }, + }, + { + "IMDS: no identity assigned", + [][]mock.ResponseOption{ + {mock.WithStatusCode(http.StatusBadRequest)}, + { + mock.WithBody([]byte(`{"error":"invalid_request","error_description":"Identity not found"}`)), + mock.WithStatusCode(http.StatusBadRequest), + }, + }, + }, + { + "no token in response", + [][]mock.ResponseOption{ + {mock.WithStatusCode(http.StatusOK)}, + {mock.WithBody([]byte(`{"error": "no token here"}`)), mock.WithStatusCode(http.StatusOK)}, + }, + }, + { + "non-JSON token response", + [][]mock.ResponseOption{ + {mock.WithStatusCode(http.StatusOK)}, + {mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)}, + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + for _, res := range test.res { + srv.AppendResponse(res...) + } + c, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ + ClientOptions: policy.ClientOptions{Transport: srv}, + }) + require.NoError(t, err) + tk, err := c.GetToken(ctx, testTRO) + require.NoError(t, err, "expected a token from AzureCLICredential") + require.Equal(t, tokenValue, tk.Token, "expected a token from AzureCLICredential") + }) + } +} + func TestDefaultAzureCredential_UnsupportedMIClientID(t *testing.T) { fail := true before := defaultAzTokenProvider diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go index 1c9f08393532..aa604b266269 100644 --- a/sdk/azidentity/managed_identity_client_test.go +++ b/sdk/azidentity/managed_identity_client_test.go @@ -8,8 +8,6 @@ package azidentity import ( "context" - "errors" - "fmt" "net/http" "net/url" "strings" @@ -17,8 +15,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" - "github.com/stretchr/testify/require" ) type userAgentValidatingPolicy struct { @@ -77,127 +73,3 @@ func TestManagedIdentityClient_ApplicationID(t *testing.T) { t.Fatal(err) } } - -func TestManagedIdentityClient_IMDSErrors(t *testing.T) { - for _, test := range []struct { - body, desc string - code int - }{ - { - desc: "No identity assigned", - code: http.StatusBadRequest, - body: `{"error":"invalid_request","error_description":"Identity not found"}`, - }, - { - desc: "Docker Desktop", - code: http.StatusForbidden, - body: "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable network.", - }, - { - desc: "Docker Desktop", - code: http.StatusForbidden, - body: "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable host.", - }, - { - desc: "Azure Container Instances", - code: http.StatusBadRequest, - body: "Required metadata header not specified or not correct", - }, - } { - t.Run(fmt.Sprint(test.code), func(t *testing.T) { - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - srv.SetResponse(mock.WithBody([]byte(test.body)), mock.WithStatusCode(test.code)) - client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{ - ClientOptions: azcore.ClientOptions{Transport: srv}, - }) - if err != nil { - t.Fatal(err) - } - _, err = client.authenticate(context.Background(), nil, testTRO.Scopes) - if err == nil { - t.Fatal("expected an error") - } - if actual := err.Error(); !strings.Contains(actual, test.body) { - t.Fatalf("expected response body in error, got %q", actual) - } - var unavailableErr credentialUnavailable - if !errors.As(err, &unavailableErr) { - t.Fatalf("expected %T, got %T", unavailableErr, err) - } - }) - } -} - -func TestManagedIdentityClient_IMDSProbeReturnsUnavailableError(t *testing.T) { - for _, test := range []struct { - desc string - res [][]mock.ResponseOption - }{ - { - "Azure Container Instance", - [][]mock.ResponseOption{ - { - mock.WithBody([]byte("Required metadata header not specified or not correct")), - mock.WithStatusCode(http.StatusBadRequest), - }, - {mock.WithBody([]byte("error")), mock.WithStatusCode(http.StatusBadRequest)}, - }, - }, - { - "404", - [][]mock.ResponseOption{ - {mock.WithStatusCode(http.StatusNotFound)}, - {mock.WithStatusCode(http.StatusNotFound)}, - }, - }, - { - "non-JSON token response", - [][]mock.ResponseOption{ - {mock.WithBody([]byte("Required metadata header not specified or not correct"))}, - {mock.WithBody([]byte("not json"))}, - }, - }, - { - "no token in response", - [][]mock.ResponseOption{ - {mock.WithBody([]byte("Required metadata header not specified or not correct"))}, - {mock.WithBody([]byte(`{"error": "no token here"}`)), mock.WithStatusCode(http.StatusBadRequest)}, - }, - }, - } { - t.Run(test.desc, func(t *testing.T) { - srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) - defer close() - for _, res := range test.res { - srv.AppendResponse(res...) - } - c, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - // disabling retries simplifies the 404 test (404 is retriable for IMDS) - Retry: policy.RetryOptions{MaxRetries: -1}, - Transport: srv, - }, - dac: true, - }) - require.NoError(t, err) - _, err = c.authenticate(ctx, nil, testTRO.Scopes) - var cu credentialUnavailable - require.ErrorAs(t, err, &cu) - - srv.AppendResponse( - mock.WithBody(accessTokenRespSuccess), - mock.WithPredicate(func(r *http.Request) bool { - if _, ok := r.Header["Metadata"]; !ok { - t.Error("client shouldn't send another probe after receiving a response") - } - return true - }), - ) - srv.AppendResponse() - tk, err := c.authenticate(ctx, nil, testTRO.Scopes) - require.NoError(t, err) - require.Equal(t, tokenValue, tk.Token) - }) - } -} From 7da1342f267da64a1dc76d9b9f6f88f203a4dc86 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 13 Jan 2025 22:01:46 +0000 Subject: [PATCH 5/8] compat: outside chains, return auth failed for unexpected responses --- sdk/azidentity/chained_token_credential.go | 3 +++ sdk/azidentity/managed_identity_client.go | 7 +++++-- .../managed_identity_credential_test.go | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 41e908a7c843..82342a02545d 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -49,6 +49,9 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil return nil, errors.New("sources cannot contain nil") } + if mc, ok := source.(*ManagedIdentityCredential); ok { + mc.mic.chained = true + } } cp := make([]azcore.TokenCredential, len(sources)) copy(cp, sources) diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index 6490e0639627..1a1cf026846f 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -60,7 +60,10 @@ const ( ) type managedIdentityClient struct { - azClient *azcore.Client + azClient *azcore.Client + // chained indicates whether the client is part of a credential chain. If true, the client will return + // a credentialUnavailableError instead of an AuthenticationFailedError for an unexpected IMDS response. + chained bool endpoint string id ManagedIDKind msiType msiType @@ -246,7 +249,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { tk, err := c.createAccessToken(resp) - if err != nil && c.msiType == msiTypeIMDS { + if err != nil && c.chained && c.msiType == msiTypeIMDS { // failure to unmarshal a 2xx implies the response is from something other than IMDS such as a proxy listening at // the same address. Return a credentialUnavailableError so credential chains continue to their next credential err = newCredentialUnavailableError(credNameManagedIdentity, err.Error()) diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 1940855bb157..c0458449b34a 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -660,6 +660,21 @@ func TestManagedIdentityCredential_IMDSRetries(t *testing.T) { } } +func TestManagedIdentityCredential_UnexpectedIMDSResponse(t *testing.T) { + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse(mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)) + + c, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ + ClientOptions: policy.ClientOptions{Transport: srv}, + }) + require.NoError(t, err) + + _, err = c.GetToken(ctx, testTRO) + var authFailed *AuthenticationFailedError + require.ErrorAs(t, err, &authFailed, "unexpected token response from IMDS should prompt an AuthenticationFailedError") +} + func TestManagedIdentityCredential_ServiceFabric(t *testing.T) { expectedSecret := "expected-secret" pred := func(req *http.Request) bool { From 5b1e2a10515155e23fc2dd2c3519d06577576be8 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Tue, 14 Jan 2025 00:34:44 +0000 Subject: [PATCH 6/8] align struct --- sdk/azidentity/managed_identity_client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index 1a1cf026846f..cdad05f7eadc 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -60,14 +60,14 @@ const ( ) type managedIdentityClient struct { - azClient *azcore.Client - // chained indicates whether the client is part of a credential chain. If true, the client will return - // a credentialUnavailableError instead of an AuthenticationFailedError for an unexpected IMDS response. - chained bool + azClient *azcore.Client endpoint string id ManagedIDKind msiType msiType probeIMDS bool + // chained indicates whether the client is part of a credential chain. If true, the client will return + // a credentialUnavailableError instead of an AuthenticationFailedError for an unexpected IMDS response. + chained bool } // arcKeyDirectory returns the directory expected to contain Azure Arc keys From 2f89bd47c825d6f52a91b43b74d3acd5eb4d3788 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Tue, 14 Jan 2025 08:34:20 -0800 Subject: [PATCH 7/8] prevent error type change for retriable responses --- sdk/azidentity/managed_identity_client.go | 3 +- .../managed_identity_credential_test.go | 30 ++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index cdad05f7eadc..de922f2f9184 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -277,7 +277,8 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi if err == nil && strings.Contains(string(body), "unreachable") { return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body))) } - default: + } + if c.chained { // the response may be from something other than IMDS, for example a proxy returning // 404. Return credentialUnavailableError so credential chains continue to their // next credential, include the response in the error message to help debugging diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index c0458449b34a..db09364f2680 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -663,16 +663,30 @@ func TestManagedIdentityCredential_IMDSRetries(t *testing.T) { func TestManagedIdentityCredential_UnexpectedIMDSResponse(t *testing.T) { srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) defer close() - srv.AppendResponse(mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)) + tests := [][]mock.ResponseOption{ + {mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)}, + } + // credential should return AuthenticationFailedError when a token request ends with a retriable response + ro := policy.RetryOptions{} + setIMDSRetryOptionDefaults(&ro) + for _, c := range ro.StatusCodes { + tests = append(tests, []mock.ResponseOption{mock.WithStatusCode(c)}) + } + for _, res := range tests { + srv.AppendResponse(res...) - c, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ - ClientOptions: policy.ClientOptions{Transport: srv}, - }) - require.NoError(t, err) + c, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ + ClientOptions: policy.ClientOptions{ + Retry: policy.RetryOptions{MaxRetries: -1}, + Transport: srv, + }, + }) + require.NoError(t, err) - _, err = c.GetToken(ctx, testTRO) - var authFailed *AuthenticationFailedError - require.ErrorAs(t, err, &authFailed, "unexpected token response from IMDS should prompt an AuthenticationFailedError") + _, err = c.GetToken(ctx, testTRO) + var af *AuthenticationFailedError + require.ErrorAs(t, err, &af, "unexpected token response from IMDS should prompt an AuthenticationFailedError") + } } func TestManagedIdentityCredential_ServiceFabric(t *testing.T) { From 4baac5fec3465fe81e9b365c658fcef6a851a127 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Tue, 14 Jan 2025 08:34:44 -0800 Subject: [PATCH 8/8] typo --- sdk/azidentity/managed_identity_client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index de922f2f9184..cc07fd70153a 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -233,7 +233,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi } return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg) } - // send normal token requests from now on because someothing responded + // send normal token requests from now on because something responded c.probeIMDS = false }