Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a policy to reauthorize the credential when invalid #2887

Merged
merged 13 commits into from
Jan 22, 2025
39 changes: 30 additions & 9 deletions cmd/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,8 @@ func (cca *CookedCopyCmdArgs) processRedirectionDownload(blobResource common.Res
}

// step 1: create client options
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil)}
// note: dstCred is nil, as we could not reauth effectively because stdout is a pipe.
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil, nil)}

// step 2: parse source url
u, err := blobResource.FullURL()
Expand Down Expand Up @@ -1386,8 +1387,15 @@ func (cca *CookedCopyCmdArgs) processRedirectionUpload(blobResource common.Resou
return fmt.Errorf("fatal: cannot find auth on destination blob URL: %s", err.Error())
}

var reauthTok *common.ScopedAuthenticator
if at, ok := credInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

// step 0: initialize pipeline
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil)}
// Reauthentication is theoretically possible here, since stdin is blocked.
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)}

// step 1: parse destination url
u, err := blobResource.FullURL()
Expand Down Expand Up @@ -1569,18 +1577,25 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
},
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder)
if err != nil {
return err
}

var srcReauth *common.ScopedAuthenticator
if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
srcReauth = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauth)
var azureFileSpecificOptions any
if cca.FromTo.From() == common.ELocation.File() {
azureFileSpecificOptions = &common.FileClientOptions{
AllowTrailingDot: cca.trailingDot.IsEnabled(),
}
}

srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder)
if err != nil {
return err
}
jobPartOrder.SrcServiceClient, err = common.GetServiceClientForLocation(
cca.FromTo.From(),
cca.Source,
Expand All @@ -1600,11 +1615,17 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
}
}

var srcCred *common.ScopedCredential
var dstReauthTok *common.ScopedAuthenticator
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

var srcCred *common.ScopedToken
if cca.FromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() {
srcCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType)
}
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred)
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, dstReauthTok)
jobPartOrder.DstServiceClient, err = common.GetServiceClientForLocation(
cca.FromTo.To(),
cca.Destination,
Expand Down
10 changes: 8 additions & 2 deletions cmd/copyEnumeratorInit.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrder common.CopyJobPartOrde
return NewCopyEnumerator(traverser, filters, processor, finalizer), nil
}

// This is condensed down into an individual function as we don't end up re-using the destination traverser at all.
// This is condensed down into an individual function as we don't end up reusing the destination traverser at all.
// This is just for the directory check.
func (cca *CookedCopyCmdArgs) isDestDirectory(dst common.ResourceString, ctx *context.Context) bool {
var err error
Expand Down Expand Up @@ -428,7 +428,13 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA
return err
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
var reauthTok *common.ScopedAuthenticator
if at, ok := dstCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)

sc, err := common.GetServiceClientForLocation(
cca.FromTo.To(),
Expand Down
8 changes: 4 additions & 4 deletions cmd/credentialUtil.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk
MaxRetryDelay: ste.UploadMaxRetryDelay,
}, policy.TelemetryOptions{
ApplicationID: common.AddUserAgentPrefix(common.UserAgent),
}, nil, ste.LogOptions{}, nil)
}, nil, ste.LogOptions{}, nil, nil)

blobClient, _ := blob.NewClientWithNoCredential(bURLParts.String(), &blob.ClientOptions{ClientOptions: clientOptions})
bURLParts.BlobName = ""
Expand Down Expand Up @@ -398,7 +398,7 @@ func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions
MaxRetryDelay: ste.UploadMaxRetryDelay,
}, policy.TelemetryOptions{
ApplicationID: common.AddUserAgentPrefix(common.UserAgent),
}, nil, ste.LogOptions{}, nil)
}, nil, ste.LogOptions{}, nil, nil)

blobClient, _ := blob.NewClientWithNoCredential(blobResourceURL, &blob.ClientOptions{ClientOptions: clientOptions})
_, err := blobClient.GetProperties(ctx, &blob.GetPropertiesOptions{CPKInfo: cpkOptions.GetCPKInfo()})
Expand Down Expand Up @@ -577,7 +577,7 @@ func getCredentialType(ctx context.Context, raw rawFromToInfo, cpkOptions common
// createClientOptions creates generic client options which are required to create any
// client to interact with storage service. Default options are modified to suit azcopy.
// srcCred is required in cases where source is authenticated via oAuth for S2S transfers
func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedCredential) azcore.ClientOptions {
func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedToken, reauthCred *common.ScopedAuthenticator) azcore.ClientOptions {
logOptions := ste.LogOptions{}

if logger != nil {
Expand All @@ -592,7 +592,7 @@ func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedC
MaxRetryDelay: ste.UploadMaxRetryDelay,
}, policy.TelemetryOptions{
ApplicationID: common.AddUserAgentPrefix(common.UserAgent),
}, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred)
}, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred, reauthCred)
}

const frontEndMaxIdleConnectionsPerHost = http.DefaultMaxIdleConnsPerHost
12 changes: 9 additions & 3 deletions cmd/jobsResume.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,19 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
}
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
var reauthTok *common.ScopedAuthenticator
if at, ok := tc.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}
jobID, err := common.ParseJobID(rca.jobID)
if err != nil {
// Error for invalid JobId format
return nil, nil, fmt.Errorf("error parsing the jobId %s. Failed with error %s", rca.jobID, err.Error())
}

// But we don't want to supply a reauth token if we're not using OAuth. That could cause problems if say, a SAS is invalid.
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, common.Iff(srcCredType.IsAzureOAuth(), reauthTok, nil))
var getJobDetailsResponse common.GetJobDetailsResponse
// Get job details from the STE
Rpc(common.ERpcCmd.GetJobDetails(),
Expand All @@ -326,11 +332,11 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
return nil, nil, err
}

var srcCred *common.ScopedCredential
var srcCred *common.ScopedToken
if fromTo.IsS2S() && srcCredType.IsAzureOAuth() {
srcCred = common.NewScopedCredential(tc, srcCredType)
}
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred)
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, common.Iff(dstCredType.IsAzureOAuth(), reauthTok, nil))
var fileClientOptions any
if fromTo.To() == common.ELocation.File() {
fileClientOptions = &common.FileClientOptions{
Expand Down
8 changes: 7 additions & 1 deletion cmd/make.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ func (cookedArgs cookedMakeCmdArgs) process() (err error) {
return err
}

var reauthTok *common.ScopedAuthenticator
if at, ok := credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

// Note : trailing dot is only applicable to file operations anyway, so setting this to false
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
resourceURL := cookedArgs.resourceURL.String()
cred := credentialInfo.OAuthTokenInfo.TokenCredential

Expand Down
17 changes: 15 additions & 2 deletions cmd/removeEnumerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,14 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er
if !from.SupportsTrailingDot() {
cca.trailingDot = common.ETrailingDotOption.Disable()
}
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)

var reauthTok *common.ScopedAuthenticator
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
var fileClientOptions any
if cca.FromTo.From() == common.ELocation.File() {
fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot.IsEnabled()}
Expand Down Expand Up @@ -142,7 +149,13 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er
func removeBfsResources(cca *CookedCopyCmdArgs) (err error) {
ctx := context.WithValue(context.Background(), ste.ServiceAPIVersionOverride, ste.DefaultServiceApiVersion)
sourceURL, _ := cca.Source.String()
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
var reauthTok *common.ScopedAuthenticator
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)

targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func beginDetectNewVersion() chan struct{} {
PrintOlderVersion(*cachedVersion, *localVersion)
} else {
// step 2: initialize pipeline
options := createClientOptions(nil, nil)
options := createClientOptions(nil, nil, nil)

// step 3: start download
blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options})
Expand Down
8 changes: 7 additions & 1 deletion cmd/setPropertiesEnumerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator
jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo)
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
var reauthTok *common.ScopedAuthenticator
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
var fileClientOptions any
if cca.FromTo.From() == common.ELocation.File() {
fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot.IsEnabled()}
Expand Down
18 changes: 15 additions & 3 deletions cmd/syncEnumerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,13 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s
},
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
var srcReauthTok *common.ScopedAuthenticator
if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
srcReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauthTok)

// Create Source Client.
var azureFileSpecificOptions any
Expand Down Expand Up @@ -209,12 +215,18 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s
}
}

var srcTokenCred *common.ScopedCredential
var dstReauthTok *common.ScopedAuthenticator
if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}

var srcTokenCred *common.ScopedToken
if cca.fromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() {
srcTokenCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType)
}

options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred)
options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred, dstReauthTok)
copyJobTemplate.DstServiceClient, err = common.GetServiceClientForLocation(
cca.fromTo.To(),
cca.destination,
Expand Down
2 changes: 1 addition & 1 deletion cmd/versionChecker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func TestCheckReleaseMetadata(t *testing.T) {
a := assert.New(t)

// sanity test for checking if the release metadata exists and can be downloaded
options := createClientOptions(nil, nil)
options := createClientOptions(nil, nil, nil)

blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options})
a.NoError(err)
Expand Down
10 changes: 9 additions & 1 deletion cmd/zc_enumerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,15 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat
return output, nil
}

options := createClientOptions(azcopyScanningLogger, nil)
var reauthTok *common.ScopedAuthenticator
if credential != nil {
if at, ok := credential.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
}
}

options := createClientOptions(azcopyScanningLogger, nil, reauthTok)

switch location {
case common.ELocation.Local():
Expand Down
4 changes: 2 additions & 2 deletions cmd/zt_parseSize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestParseSize(t *testing.T) {
_, err = ParseSizeString("123T", "foo-bar") // we don't support terabytes
a.Equal(expectedError, err.Error())

_, err = ParseSizeString("abcK", "foo-bar")
_, err = ParseSizeString("abcK", "foo-bar") //codespell:ignore
a.Equal(expectedError, err.Error())

}
}
10 changes: 10 additions & 0 deletions common/oauthTokenManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: authorityHost.String()},
Transport: newAzcopyHTTPClient(),
},
UserPrompt: func(ctx context.Context, message azidentity.DeviceCodeMessage) error {
lcm.Info(fmt.Sprintf("Authentication is required. To sign in, open the webpage %s and enter the code %s to authenticate.",
Iff(message.VerificationURL != "", message.VerificationURL, "https://aka.ms/devicelogin"), message.UserCode))
return nil
},
})
if err != nil {
return nil, err
Expand All @@ -727,6 +732,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia
return tc, nil
}

type AuthenticateToken interface {
azcore.TokenCredential
Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error)
}

func (credInfo *OAuthTokenInfo) GetTokenCredential() (azcore.TokenCredential, error) {
// Token Credential is cached.
if credInfo.TokenCredential != nil {
Expand Down
1 change: 1 addition & 0 deletions common/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ var EPromptType = PromptType("")

type PromptType string

func (PromptType) Reauth() PromptType { return PromptType("Reauth") }
func (PromptType) Cancel() PromptType { return PromptType("Cancel") }
func (PromptType) Overwrite() PromptType { return PromptType("Overwrite") }
func (PromptType) DeleteDestination() PromptType { return PromptType("DeleteDestination") }
Expand Down
22 changes: 17 additions & 5 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"net"
"net/url"
"strings"
Expand Down Expand Up @@ -231,7 +232,7 @@ func GetServiceClientForLocation(loc Location,
// NewScopedCredential takes in a credInfo object and returns ScopedCredential
// if credentialType is either MDOAuth or oAuth. For anything else,
// nil is returned
func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) *ScopedCredential {
func NewScopedCredential[T azcore.TokenCredential](cred T, credType CredentialType) *ScopedCredential[T] {
var scope string
if !credType.IsAzureOAuth() {
return nil
Expand All @@ -240,18 +241,29 @@ func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) *
} else if credType == ECredentialType.OAuthToken() {
scope = StorageScope
}
return &ScopedCredential{cred: cred, scopes: []string{scope}}
return &ScopedCredential[T]{cred: cred, scopes: []string{scope}}
}

type ScopedCredential struct {
cred azcore.TokenCredential
type ScopedCredential[T azcore.TokenCredential] struct {
cred T
scopes []string
}

func (s *ScopedCredential) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
func (s *ScopedCredential[T]) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true})
}

type ScopedToken = ScopedCredential[azcore.TokenCredential]
type ScopedAuthenticator ScopedCredential[AuthenticateToken]

func (s *ScopedAuthenticator) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true})
}

func (s *ScopedAuthenticator) Authenticate(ctx context.Context, _ *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) {
return s.cred.Authenticate(ctx, &policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true})
}

type ServiceClient struct {
fsc *fileservice.Client
bsc *blobservice.Client
Expand Down
Loading
Loading