Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Credential chains continue iterating after unexpected IMDS responses #23894

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@

### Bugs Fixed
* User credential types inconsistently log access token scopes
* `DefaultAzureCredential` skips managed identity in Azure Container Instances

### Other Changes
* `ChainedTokenCredential` and `DefaultAzureCredential` continue to their next
credential after `ManagedIdentityCredential` receives an unexpected response
from IMDS, indicating the response is from something else such as a proxy

## 1.8.0 (2024-10-08)

Expand Down
3 changes: 3 additions & 0 deletions sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine
if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil
return nil, errors.New("sources cannot contain nil")
}
if mc, ok := source.(*ManagedIdentityCredential); ok {
mc.mic.chained = true
}
}
cp := make([]azcore.TokenCredential, len(sources))
copy(cp, sources)
Expand Down
117 changes: 96 additions & 21 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -322,7 +321,6 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
probed = true
require.Empty(t, hdr, "probe request shouldn't have Metadata header")
return &http.Response{
Body: io.NopCloser(strings.NewReader("{}")),
StatusCode: http.StatusInternalServerError,
}
},
Expand All @@ -335,27 +333,26 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
require.True(t, probed)
require.Equal(t, tokenValue, tk.Token)

t.Run("non-JSON response", func(t *testing.T) {
t.Run("Azure Container Instances", func(t *testing.T) {
// ensure GetToken returns an error if DefaultAzureCredential skips managed identity
before := defaultAzTokenProvider
defer func() { defaultAzTokenProvider = before }()
defaultAzTokenProvider = mockAzTokenProviderSuccess
for _, res := range [][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusNotFound)},
{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusBadRequest)},
{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)},
} {
srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.SetResponse(res...)
cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: srv,
},
})
require.NoError(t, err)
_, err = cred.GetToken(ctx, testTRO)
require.NoError(t, err, "DefaultAzureCredential should continue after receiving a non-JSON response from IMDS")
}
defaultAzTokenProvider = mockAzTokenProviderFailure

srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(mock.WithBody([]byte("Required metadata header not specified or not correct")), mock.WithStatusCode(http.StatusBadRequest))
srv.AppendResponse(mock.WithBody(accessTokenRespSuccess), mock.WithStatusCode(http.StatusOK))

cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: srv,
},
})
require.NoError(t, err)
tk, err := cred.GetToken(ctx, testTRO)
require.NoError(t, err, "DefaultAzureCredential should accept ACI's response to the probe request")
require.Equal(t, tokenValue, tk.Token)
})
})

Expand Down Expand Up @@ -397,6 +394,84 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
})
}

func TestDefaultAzureCredential_UnexpectedIMDSResponse(t *testing.T) {
before := defaultAzTokenProvider
defer func() { defaultAzTokenProvider = before }()
defaultAzTokenProvider = mockAzTokenProviderSuccess

const dockerDesktopPrefix = "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable "
for _, test := range []struct {
desc string
res [][]mock.ResponseOption
}{
{
"Docker Desktop",
[][]mock.ResponseOption{
{
mock.WithBody([]byte(dockerDesktopPrefix + "host.")),
mock.WithStatusCode(http.StatusForbidden),
},
{
mock.WithBody([]byte(dockerDesktopPrefix + "host.")),
mock.WithStatusCode(http.StatusForbidden),
},
},
},
{
"Docker Desktop",
[][]mock.ResponseOption{
{
mock.WithBody([]byte(dockerDesktopPrefix + "network.")),
mock.WithStatusCode(http.StatusForbidden),
},
{
mock.WithBody([]byte(dockerDesktopPrefix + "network.")),
mock.WithStatusCode(http.StatusForbidden),
},
},
},
{
"IMDS: no identity assigned",
[][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusBadRequest)},
{
mock.WithBody([]byte(`{"error":"invalid_request","error_description":"Identity not found"}`)),
mock.WithStatusCode(http.StatusBadRequest),
},
},
},
{
"no token in response",
[][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusOK)},
{mock.WithBody([]byte(`{"error": "no token here"}`)), mock.WithStatusCode(http.StatusOK)},
},
},
{
"non-JSON token response",
[][]mock.ResponseOption{
{mock.WithStatusCode(http.StatusOK)},
{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)},
},
},
} {
t.Run(test.desc, func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
for _, res := range test.res {
srv.AppendResponse(res...)
}
c, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: srv},
})
require.NoError(t, err)
tk, err := c.GetToken(ctx, testTRO)
require.NoError(t, err, "expected a token from AzureCLICredential")
require.Equal(t, tokenValue, tk.Token, "expected a token from AzureCLICredential")
})
}
}

func TestDefaultAzureCredential_UnsupportedMIClientID(t *testing.T) {
fail := true
before := defaultAzTokenProvider
Expand Down
36 changes: 22 additions & 14 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ type managedIdentityClient struct {
id ManagedIDKind
msiType msiType
probeIMDS bool
// chained indicates whether the client is part of a credential chain. If true, the client will return
// a credentialUnavailableError instead of an AuthenticationFailedError for an unexpected IMDS response.
chained bool
}

// arcKeyDirectory returns the directory expected to contain Azure Arc keys
Expand Down Expand Up @@ -215,31 +218,22 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
// no need to synchronize around this value because it's true only when DefaultAzureCredential constructed the client,
// and in that case ChainedTokenCredential.GetToken synchronizes goroutines that would execute this block
if c.probeIMDS {
// send a malformed request (no Metadata header) to IMDS to determine whether the endpoint is available
cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout)
defer cancel()
cx = policy.WithRetryOptions(cx, policy.RetryOptions{MaxRetries: -1})
req, err := azruntime.NewRequest(cx, http.MethodGet, c.endpoint)
if err != nil {
return azcore.AccessToken{}, fmt.Errorf("failed to create IMDS probe request: %s", err)
}
res, err := c.azClient.Pipeline().Do(req)
if err != nil {
if _, err = c.azClient.Pipeline().Do(req); err != nil {
msg := err.Error()
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
msg = "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information"
}
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
}
// because IMDS always responds with JSON, assume a non-JSON response is from something else, such
// as a proxy, and return credentialUnavailableError so DefaultAzureCredential continues iterating
b, err := azruntime.Payload(res)
if err != nil {
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("failed to read IMDS probe response: %s", err))
}
if !json.Valid(b) {
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, "unexpected response to IMDS probe")
}
// send normal token requests from now on because IMDS responded
// send normal token requests from now on because someothing responded
c.probeIMDS = false
}

Expand All @@ -254,13 +248,21 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
}

if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
return c.createAccessToken(resp)
tk, err := c.createAccessToken(resp)
if err != nil && c.chained && c.msiType == msiTypeIMDS {
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
// failure to unmarshal a 2xx implies the response is from something other than IMDS such as a proxy listening at
// the same address. Return a credentialUnavailableError so credential chains continue to their next credential
err = newCredentialUnavailableError(credNameManagedIdentity, err.Error())
}
return tk, err
}

if c.msiType == msiTypeIMDS {
switch resp.StatusCode {
case http.StatusBadRequest:
if id != nil {
// return authenticationFailedError, halting any encompassing credential chain,
// because the explicit user-assigned identity implies the developer expected this to work
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp)
}
msg := "failed to authenticate a system assigned identity"
Expand All @@ -275,6 +277,12 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
if err == nil && strings.Contains(string(body), "unreachable") {
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body)))
}
default:
// the response may be from something other than IMDS, for example a proxy returning
// 404. Return credentialUnavailableError so credential chains continue to their
// next credential, include the response in the error message to help debugging
err = newAuthenticationFailedError(credNameManagedIdentity, "", resp)
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, err.Error())
}
}

Expand All @@ -290,7 +298,7 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac
ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
}{}
if err := azruntime.UnmarshalAsJSON(res, &value); err != nil {
return azcore.AccessToken{}, fmt.Errorf("internal AccessToken: %v", err)
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "Unexpected response content", res)
}
if value.ExpiresIn != "" {
expiresIn, err := json.Number(value.ExpiresIn).Int64()
Expand Down
49 changes: 0 additions & 49 deletions sdk/azidentity/managed_identity_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ package azidentity

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

type userAgentValidatingPolicy struct {
Expand Down Expand Up @@ -76,49 +73,3 @@ func TestManagedIdentityClient_ApplicationID(t *testing.T) {
t.Fatal(err)
}
}

func TestManagedIdentityClient_IMDSErrors(t *testing.T) {
for _, test := range []struct {
body, desc string
code int
}{
{
desc: "No identity assigned",
code: http.StatusBadRequest,
body: `{"error":"invalid_request","error_description":"Identity not found"}`,
},
{
desc: "Docker Desktop",
code: http.StatusForbidden,
body: "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable network.",
},
{
desc: "Docker Desktop",
code: http.StatusForbidden,
body: "connecting to 169.254.169.254:80: connecting to 169.254.169.254:80: dial tcp 169.254.169.254:80: connectex: A socket operation was attempted to an unreachable host.",
},
} {
t.Run(fmt.Sprint(test.code), func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.SetResponse(mock.WithBody([]byte(test.body)), mock.WithStatusCode(test.code))
client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{
ClientOptions: azcore.ClientOptions{Transport: srv},
})
if err != nil {
t.Fatal(err)
}
_, err = client.authenticate(context.Background(), nil, testTRO.Scopes)
if err == nil {
t.Fatal("expected an error")
}
if actual := err.Error(); !strings.Contains(actual, test.body) {
t.Fatalf("expected response body in error, got %q", actual)
}
var unavailableErr credentialUnavailable
if !errors.As(err, &unavailableErr) {
t.Fatalf("expected %T, got %T", unavailableErr, err)
}
})
}
}
15 changes: 15 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,21 @@ func TestManagedIdentityCredential_IMDSRetries(t *testing.T) {
}
}

func TestManagedIdentityCredential_UnexpectedIMDSResponse(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK))

c, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: srv},
})
require.NoError(t, err)

_, err = c.GetToken(ctx, testTRO)
var authFailed *AuthenticationFailedError
require.ErrorAs(t, err, &authFailed, "unexpected token response from IMDS should prompt an AuthenticationFailedError")
}

func TestManagedIdentityCredential_ServiceFabric(t *testing.T) {
expectedSecret := "expected-secret"
pred := func(req *http.Request) bool {
Expand Down
Loading
Loading