Skip to content

Commit

Permalink
Credential chains continue iterating after unexpected IMDS responses (
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jan 14, 2025
1 parent 7d4721b commit 786b0be
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 93 deletions.
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@

### Bugs Fixed
* User credential types inconsistently log access token scopes
* `DefaultAzureCredential` skips managed identity in Azure Container Instances

### Other Changes
* `ChainedTokenCredential` and `DefaultAzureCredential` continue to their next
credential after `ManagedIdentityCredential` receives an unexpected response
from IMDS, indicating the response is from something else such as a proxy

## 1.8.0 (2024-10-08)

Expand Down
3 changes: 3 additions & 0 deletions sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine
if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil
return nil, errors.New("sources cannot contain nil")
}
if mc, ok := source.(*ManagedIdentityCredential); ok {
mc.mic.chained = true
}
}
cp := make([]azcore.TokenCredential, len(sources))
copy(cp, sources)
Expand Down
117 changes: 96 additions & 21 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -322,7 +321,6 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
probed = true
require.Empty(t, hdr, "probe request shouldn't have Metadata header")
return &http.Response{
Body: io.NopCloser(strings.NewReader("{}")),
StatusCode: http.StatusInternalServerError,
}
},
Expand All @@ -335,27 +333,26 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
require.True(t, probed)
require.Equal(t, tokenValue, tk.Token)

t.Run("non-JSON response", func(t *testing.T) {
t.Run("Azure Container Instances", func(t *testing.T) {
// ensure GetToken returns an error if DefaultAzureCredential skips managed identity
before := defaultAzTokenProvider
defer func() { defaultAzTokenProvider = before }()
defaultAzTokenProvider = mockAzTokenProviderSuccess
for _, res := range [][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusNotFound)},
{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusBadRequest)},
{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)},
} {
srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.SetResponse(res...)
cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: srv,
},
})
require.NoError(t, err)
_, err = cred.GetToken(ctx, testTRO)
require.NoError(t, err, "DefaultAzureCredential should continue after receiving a non-JSON response from IMDS")
}
defaultAzTokenProvider = mockAzTokenProviderFailure

srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(mock.WithBody([]byte("Required metadata header not specified or not correct")), mock.WithStatusCode(http.StatusBadRequest))
srv.AppendResponse(mock.WithBody(accessTokenRespSuccess), mock.WithStatusCode(http.StatusOK))

cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: srv,
},
})
require.NoError(t, err)
tk, err := cred.GetToken(ctx, testTRO)
require.NoError(t, err, "DefaultAzureCredential should accept ACI's response to the probe request")
require.Equal(t, tokenValue, tk.Token)
})
})

Expand Down Expand Up @@ -397,6 +394,84 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
})
}

func TestDefaultAzureCredential_UnexpectedIMDSResponse(t *testing.T) {
before := defaultAzTokenProvider
defer func() { defaultAzTokenProvider = before }()
defaultAzTokenProvider = mockAzTokenProviderSuccess

const dockerDesktopPrefix = "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable "
for _, test := range []struct {
desc string
res [][]mock.ResponseOption
}{
{
"Docker Desktop",
[][]mock.ResponseOption{
{
mock.WithBody([]byte(dockerDesktopPrefix + "host.")),
mock.WithStatusCode(http.StatusForbidden),
},
{
mock.WithBody([]byte(dockerDesktopPrefix + "host.")),
mock.WithStatusCode(http.StatusForbidden),
},
},
},
{
"Docker Desktop",
[][]mock.ResponseOption{
{
mock.WithBody([]byte(dockerDesktopPrefix + "network.")),
mock.WithStatusCode(http.StatusForbidden),
},
{
mock.WithBody([]byte(dockerDesktopPrefix + "network.")),
mock.WithStatusCode(http.StatusForbidden),
},
},
},
{
"IMDS: no identity assigned",
[][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusBadRequest)},
{
mock.WithBody([]byte(`{"error":"invalid_request","error_description":"Identity not found"}`)),
mock.WithStatusCode(http.StatusBadRequest),
},
},
},
{
"no token in response",
[][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusOK)},
{mock.WithBody([]byte(`{"error": "no token here"}`)), mock.WithStatusCode(http.StatusOK)},
},
},
{
"non-JSON token response",
[][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusOK)},
{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)},
},
},
} {
t.Run(test.desc, func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
for _, res := range test.res {
srv.AppendResponse(res...)
}
c, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: srv},
})
require.NoError(t, err)
tk, err := c.GetToken(ctx, testTRO)
require.NoError(t, err, "expected a token from AzureCLICredential")
require.Equal(t, tokenValue, tk.Token, "expected a token from AzureCLICredential")
})
}
}

func TestDefaultAzureCredential_UnsupportedMIClientID(t *testing.T) {
fail := true
before := defaultAzTokenProvider
Expand Down
37 changes: 23 additions & 14 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ type managedIdentityClient struct {
id ManagedIDKind
msiType msiType
probeIMDS bool
// chained indicates whether the client is part of a credential chain. If true, the client will return
// a credentialUnavailableError instead of an AuthenticationFailedError for an unexpected IMDS response.
chained bool
}

// arcKeyDirectory returns the directory expected to contain Azure Arc keys
Expand Down Expand Up @@ -215,31 +218,22 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
// no need to synchronize around this value because it's true only when DefaultAzureCredential constructed the client,
// and in that case ChainedTokenCredential.GetToken synchronizes goroutines that would execute this block
if c.probeIMDS {
// send a malformed request (no Metadata header) to IMDS to determine whether the endpoint is available
cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout)
defer cancel()
cx = policy.WithRetryOptions(cx, policy.RetryOptions{MaxRetries: -1})
req, err := azruntime.NewRequest(cx, http.MethodGet, c.endpoint)
if err != nil {
return azcore.AccessToken{}, fmt.Errorf("failed to create IMDS probe request: %s", err)
}
res, err := c.azClient.Pipeline().Do(req)
if err != nil {
if _, err = c.azClient.Pipeline().Do(req); err != nil {
msg := err.Error()
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
msg = "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information"
}
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
}
// because IMDS always responds with JSON, assume a non-JSON response is from something else, such
// as a proxy, and return credentialUnavailableError so DefaultAzureCredential continues iterating
b, err := azruntime.Payload(res)
if err != nil {
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("failed to read IMDS probe response: %s", err))
}
if !json.Valid(b) {
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, "unexpected response to IMDS probe")
}
// send normal token requests from now on because IMDS responded
// send normal token requests from now on because something responded
c.probeIMDS = false
}

Expand All @@ -254,13 +248,21 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
}

if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
return c.createAccessToken(resp)
tk, err := c.createAccessToken(resp)
if err != nil && c.chained && c.msiType == msiTypeIMDS {
// failure to unmarshal a 2xx implies the response is from something other than IMDS such as a proxy listening at
// the same address. Return a credentialUnavailableError so credential chains continue to their next credential
err = newCredentialUnavailableError(credNameManagedIdentity, err.Error())
}
return tk, err
}

if c.msiType == msiTypeIMDS {
switch resp.StatusCode {
case http.StatusBadRequest:
if id != nil {
// return authenticationFailedError, halting any encompassing credential chain,
// because the explicit user-assigned identity implies the developer expected this to work
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp)
}
msg := "failed to authenticate a system assigned identity"
Expand All @@ -276,6 +278,13 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body)))
}
}
if c.chained {
// the response may be from something other than IMDS, for example a proxy returning
// 404. Return credentialUnavailableError so credential chains continue to their
// next credential, include the response in the error message to help debugging
err = newAuthenticationFailedError(credNameManagedIdentity, "", resp)
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, err.Error())
}
}

return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "", resp)
Expand All @@ -290,7 +299,7 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac
ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
}{}
if err := azruntime.UnmarshalAsJSON(res, &value); err != nil {
return azcore.AccessToken{}, fmt.Errorf("internal AccessToken: %v", err)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "Unexpected response content", res)
}
if value.ExpiresIn != "" {
expiresIn, err := json.Number(value.ExpiresIn).Int64()
Expand Down
49 changes: 0 additions & 49 deletions sdk/azidentity/managed_identity_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ package azidentity

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

type userAgentValidatingPolicy struct {
Expand Down Expand Up @@ -76,49 +73,3 @@ func TestManagedIdentityClient_ApplicationID(t *testing.T) {
t.Fatal(err)
}
}

func TestManagedIdentityClient_IMDSErrors(t *testing.T) {
for _, test := range []struct {
body, desc string
code int
}{
{
desc: "No identity assigned",
code: http.StatusBadRequest,
body: `{"error":"invalid_request","error_description":"Identity not found"}`,
},
{
desc: "Docker Desktop",
code: http.StatusForbidden,
body: "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable network.",
},
{
desc: "Docker Desktop",
code: http.StatusForbidden,
body: "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable host.",
},
} {
t.Run(fmt.Sprint(test.code), func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.SetResponse(mock.WithBody([]byte(test.body)), mock.WithStatusCode(test.code))
client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{Transport: srv},
})
if err != nil {
t.Fatal(err)
}
_, err = client.authenticate(context.Background(), nil, testTRO.Scopes)
if err == nil {
t.Fatal("expected an error")
}
if actual := err.Error(); !strings.Contains(actual, test.body) {
t.Fatalf("expected response body in error, got %q", actual)
}
var unavailableErr credentialUnavailable
if !errors.As(err, &unavailableErr) {
t.Fatalf("expected %T, got %T", unavailableErr, err)
}
})
}
}
29 changes: 29 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,35 @@ func TestManagedIdentityCredential_IMDSRetries(t *testing.T) {
}
}

func TestManagedIdentityCredential_UnexpectedIMDSResponse(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
tests := [][]mock.ResponseOption{
{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)},
}
// credential should return AuthenticationFailedError when a token request ends with a retriable response
ro := policy.RetryOptions{}
setIMDSRetryOptionDefaults(&ro)
for _, c := range ro.StatusCodes {
tests = append(tests, []mock.ResponseOption{mock.WithStatusCode(c)})
}
for _, res := range tests {
srv.AppendResponse(res...)

c, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{
Retry: policy.RetryOptions{MaxRetries: -1},
Transport: srv,
},
})
require.NoError(t, err)

_, err = c.GetToken(ctx, testTRO)
var af *AuthenticationFailedError
require.ErrorAs(t, err, &af, "unexpected token response from IMDS should prompt an AuthenticationFailedError")
}
}

func TestManagedIdentityCredential_ServiceFabric(t *testing.T) {
expectedSecret := "expected-secret"
pred := func(req *http.Request) bool {
Expand Down
Loading

0 comments on commit 786b0be

Please sign in to comment.