From 7168bbec4d0d5ec9d0f18143cb782455cac6b491 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Thu, 5 Dec 2024 18:49:02 -0800 Subject: [PATCH 01/10] Implement a policy to reauthorize the credential when invalid --- cmd/copy.go | 39 +++++-- cmd/copyEnumeratorInit.go | 8 +- cmd/credentialUtil.go | 8 +- cmd/jobsResume.go | 13 ++- cmd/make.go | 8 +- cmd/removeEnumerator.go | 17 +++- cmd/root.go | 2 +- cmd/setPropertiesEnumerator.go | 8 +- cmd/syncEnumerator.go | 18 +++- cmd/versionChecker_test.go | 2 +- cmd/zc_enumerator.go | 8 +- common/oauthTokenManager.go | 10 ++ common/output.go | 1 + common/util.go | 22 +++- ste/destReauthPolicy.go | 114 +++++++++++++++++++++ ste/mgr-JobPartMgr.go | 5 +- ste/testJobPartTransferManager_test.go | 2 +- ste/zt_destReauthPolicy_test.go | 135 +++++++++++++++++++++++++ 18 files changed, 386 insertions(+), 34 deletions(-) create mode 100644 ste/destReauthPolicy.go create mode 100644 ste/zt_destReauthPolicy_test.go diff --git a/cmd/copy.go b/cmd/copy.go index c7a87d54d..2a23993e4 100644 --- a/cmd/copy.go +++ b/cmd/copy.go @@ -1305,7 +1305,8 @@ func (cca *CookedCopyCmdArgs) processRedirectionDownload(blobResource common.Res } // step 1: create client options - options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil)} + // note: dstCred is nil, as we could not reauth effectively because stdout is a pipe. + options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil, nil)} // step 2: parse source url u, err := blobResource.FullURL() @@ -1381,8 +1382,15 @@ func (cca *CookedCopyCmdArgs) processRedirectionUpload(blobResource common.Resou return fmt.Errorf("fatal: cannot find auth on destination blob URL: %s", err.Error()) } + var reauthTok *common.ScopedAuthenticator + if at, ok := credInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + // step 0: initialize pipeline - options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil)} + // Reauthentication is theoretically possible here, since stdin is blocked. + options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)} // step 1: parse destination url u, err := blobResource.FullURL() @@ -1564,7 +1572,18 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { }, } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder) + if err != nil { + return err + } + + var srcReauth *common.ScopedAuthenticator + if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + srcReauth = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauth) var azureFileSpecificOptions any if cca.FromTo.From() == common.ELocation.File() { azureFileSpecificOptions = &common.FileClientOptions{ @@ -1572,10 +1591,6 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { } } - srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder) - if err != nil { - return err - } jobPartOrder.SrcServiceClient, err = common.GetServiceClientForLocation( cca.FromTo.From(), cca.Source, @@ -1595,11 +1610,17 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { } } - var srcCred *common.ScopedCredential + var dstReauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + var srcCred *common.ScopedToken if cca.FromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() { srcCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType) } - options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred) + options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, dstReauthTok) jobPartOrder.DstServiceClient, err = common.GetServiceClientForLocation( cca.FromTo.To(), cca.Destination, diff --git a/cmd/copyEnumeratorInit.go b/cmd/copyEnumeratorInit.go index 73c537a34..b7ef9cad7 100755 --- a/cmd/copyEnumeratorInit.go +++ b/cmd/copyEnumeratorInit.go @@ -428,7 +428,13 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA return err } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := dstCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) sc, err := common.GetServiceClientForLocation( cca.FromTo.To(), diff --git a/cmd/credentialUtil.go b/cmd/credentialUtil.go index 2137a2dd2..2ed51703c 100644 --- a/cmd/credentialUtil.go +++ b/cmd/credentialUtil.go @@ -365,7 +365,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ ApplicationID: glcm.AddUserAgentPrefix(common.UserAgent), - }, nil, ste.LogOptions{}, nil) + }, nil, ste.LogOptions{}, nil, nil) blobClient, _ := blob.NewClientWithNoCredential(bURLParts.String(), &blob.ClientOptions{ClientOptions: clientOptions}) bURLParts.BlobName = "" @@ -398,7 +398,7 @@ func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ ApplicationID: glcm.AddUserAgentPrefix(common.UserAgent), - }, nil, ste.LogOptions{}, nil) + }, nil, ste.LogOptions{}, nil, nil) blobClient, _ := blob.NewClientWithNoCredential(blobResourceURL, &blob.ClientOptions{ClientOptions: clientOptions}) _, err := blobClient.GetProperties(ctx, &blob.GetPropertiesOptions{CPKInfo: cpkOptions.GetCPKInfo()}) @@ -577,7 +577,7 @@ func getCredentialType(ctx context.Context, raw rawFromToInfo, cpkOptions common // createClientOptions creates generic client options which are required to create any // client to interact with storage service. Default options are modified to suit azcopy. // srcCred is required in cases where source is authenticated via oAuth for S2S transfers -func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedCredential) azcore.ClientOptions { +func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedToken, reauthCred *common.ScopedAuthenticator) azcore.ClientOptions { logOptions := ste.LogOptions{} if logger != nil { @@ -592,7 +592,7 @@ func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedC MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ ApplicationID: glcm.AddUserAgentPrefix(common.UserAgent), - }, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred) + }, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred, reauthCred) } const frontEndMaxIdleConnectionsPerHost = http.DefaultMaxIdleConnsPerHost diff --git a/cmd/jobsResume.go b/cmd/jobsResume.go index 053e6e319..03f0abcb5 100644 --- a/cmd/jobsResume.go +++ b/cmd/jobsResume.go @@ -294,18 +294,25 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( } } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := tc.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + // But we don't want to supply a reauth token if we're not using OAuth. That could cause problems if say, a SAS is invalid. + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, common.Iff(srcCredType.IsAzureOAuth(), reauthTok, nil)) srcServiceClient, err := common.GetServiceClientForLocation(fromTo.From(), source, srcCredType, tc, &options, nil) if err != nil { return nil, nil, err } - var srcCred *common.ScopedCredential + var srcCred *common.ScopedToken if fromTo.IsS2S() && srcCredType.IsAzureOAuth() { srcCred = common.NewScopedCredential(tc, srcCredType) } - options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred) + options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, common.Iff(dstCredType.IsAzureOAuth(), reauthTok, nil)) dstServiceClient, err := common.GetServiceClientForLocation(fromTo.To(), destination, dstCredType, tc, &options, nil) if err != nil { return nil, nil, err diff --git a/cmd/make.go b/cmd/make.go index 13c900c20..4214461c1 100644 --- a/cmd/make.go +++ b/cmd/make.go @@ -90,8 +90,14 @@ func (cookedArgs cookedMakeCmdArgs) process() (err error) { return err } + var reauthTok *common.ScopedAuthenticator + if at, ok := credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + // Note : trailing dot is only applicable to file operations anyway, so setting this to false - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) resourceURL := cookedArgs.resourceURL.String() cred := credentialInfo.OAuthTokenInfo.TokenCredential diff --git a/cmd/removeEnumerator.go b/cmd/removeEnumerator.go index e7ce7213a..b9abd6ef5 100755 --- a/cmd/removeEnumerator.go +++ b/cmd/removeEnumerator.go @@ -90,7 +90,14 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er if !from.SupportsTrailingDot() { cca.trailingDot = common.ETrailingDotOption.Disable() } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + + var reauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) var fileClientOptions any if cca.FromTo.From() == common.ELocation.File() { fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot == common.ETrailingDotOption.Enable()} @@ -142,7 +149,13 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er func removeBfsResources(cca *CookedCopyCmdArgs) (err error) { ctx := context.WithValue(context.Background(), ste.ServiceAPIVersionOverride, ste.DefaultServiceApiVersion) sourceURL, _ := cca.Source.String() - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil) if err != nil { diff --git a/cmd/root.go b/cmd/root.go index 2b53212c1..79278d15f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -310,7 +310,7 @@ func beginDetectNewVersion() chan struct{} { PrintOlderVersion(*cachedVersion, *localVersion) } else { // step 2: initialize pipeline - options := createClientOptions(nil, nil) + options := createClientOptions(nil, nil, nil) // step 3: start download blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options}) diff --git a/cmd/setPropertiesEnumerator.go b/cmd/setPropertiesEnumerator.go index 14e9be84e..7235a7a49 100755 --- a/cmd/setPropertiesEnumerator.go +++ b/cmd/setPropertiesEnumerator.go @@ -72,7 +72,13 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo) } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way. + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok) var fileClientOptions any if cca.FromTo.From() == common.ELocation.File() { fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot == common.ETrailingDotOption.Enable()} diff --git a/cmd/syncEnumerator.go b/cmd/syncEnumerator.go index 71ae2e8bf..dc244786a 100644 --- a/cmd/syncEnumerator.go +++ b/cmd/syncEnumerator.go @@ -179,7 +179,13 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s }, } - options := createClientOptions(common.AzcopyCurrentJobLogger, nil) + var srcReauthTok *common.ScopedAuthenticator + if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + srcReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauthTok) // Create Source Client. var azureFileSpecificOptions any @@ -209,12 +215,18 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s } } - var srcTokenCred *common.ScopedCredential + var dstReauthTok *common.ScopedAuthenticator + if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + srcReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + var srcTokenCred *common.ScopedToken if cca.fromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() { srcTokenCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType) } - options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred) + options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred, dstReauthTok) copyJobTemplate.DstServiceClient, err = common.GetServiceClientForLocation( cca.fromTo.To(), cca.destination, diff --git a/cmd/versionChecker_test.go b/cmd/versionChecker_test.go index c65597de9..7ddae4f83 100644 --- a/cmd/versionChecker_test.go +++ b/cmd/versionChecker_test.go @@ -215,7 +215,7 @@ func TestCheckReleaseMetadata(t *testing.T) { a := assert.New(t) // sanity test for checking if the release metadata exists and can be downloaded - options := createClientOptions(nil, nil) + options := createClientOptions(nil, nil, nil) blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options}) a.NoError(err) diff --git a/cmd/zc_enumerator.go b/cmd/zc_enumerator.go index 3c0a1ec97..cadaa3961 100644 --- a/cmd/zc_enumerator.go +++ b/cmd/zc_enumerator.go @@ -381,7 +381,13 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat return output, nil } - options := createClientOptions(azcopyScanningLogger, nil) + var reauthTok *common.ScopedAuthenticator + if at, ok := credential.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } + + options := createClientOptions(azcopyScanningLogger, nil, reauthTok) switch location { case common.ELocation.Local(): diff --git a/common/oauthTokenManager.go b/common/oauthTokenManager.go index bd1b30181..bc4273c38 100644 --- a/common/oauthTokenManager.go +++ b/common/oauthTokenManager.go @@ -701,6 +701,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: authorityHost.String()}, Transport: newAzcopyHTTPClient(), }, + UserPrompt: func(ctx context.Context, message azidentity.DeviceCodeMessage) error { + lcm.Info(fmt.Sprintf("Authentication is required. To sign in, open the webpage %s and enter the code %s to authenticate.", + Iff(message.VerificationURL != "", message.VerificationURL, "https://aka.ms/devicelogin"), message.UserCode)) + return nil + }, }) if err != nil { return nil, err @@ -727,6 +732,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia return tc, nil } +type AuthenticateToken interface { + azcore.TokenCredential + Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) +} + func (credInfo *OAuthTokenInfo) GetTokenCredential() (azcore.TokenCredential, error) { // Token Credential is cached. if credInfo.TokenCredential != nil { diff --git a/common/output.go b/common/output.go index 910f169a5..9df415dd8 100644 --- a/common/output.go +++ b/common/output.go @@ -68,6 +68,7 @@ var EPromptType = PromptType("") type PromptType string +func (PromptType) Reauth() PromptType { return PromptType("Reauth") } func (PromptType) Cancel() PromptType { return PromptType("Cancel") } func (PromptType) Overwrite() PromptType { return PromptType("Overwrite") } func (PromptType) DeleteDestination() PromptType { return PromptType("DeleteDestination") } diff --git a/common/util.go b/common/util.go index e7708dab8..b81a95d3c 100644 --- a/common/util.go +++ b/common/util.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "net" "net/url" "strings" @@ -231,7 +232,7 @@ func GetServiceClientForLocation(loc Location, // NewScopedCredential takes in a credInfo object and returns ScopedCredential // if credentialType is either MDOAuth or oAuth. For anything else, // nil is returned -func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) *ScopedCredential { +func NewScopedCredential[T azcore.TokenCredential](cred T, credType CredentialType) *ScopedCredential[T] { var scope string if !credType.IsAzureOAuth() { return nil @@ -240,18 +241,29 @@ func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) * } else if credType == ECredentialType.OAuthToken() { scope = StorageScope } - return &ScopedCredential{cred: cred, scopes: []string{scope}} + return &ScopedCredential[T]{cred: cred, scopes: []string{scope}} } -type ScopedCredential struct { - cred azcore.TokenCredential +type ScopedCredential[T azcore.TokenCredential] struct { + cred T scopes []string } -func (s *ScopedCredential) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { +func (s *ScopedCredential[T]) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true}) } +type ScopedToken = ScopedCredential[azcore.TokenCredential] +type ScopedAuthenticator ScopedCredential[AuthenticateToken] + +func (s *ScopedAuthenticator) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { + return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true}) +} + +func (s *ScopedAuthenticator) Authenticate(ctx context.Context, _ *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) { + return s.cred.Authenticate(ctx, &policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true}) +} + type ServiceClient struct { fsc *fileservice.Client bsc *blobservice.Client diff --git a/ste/destReauthPolicy.go b/ste/destReauthPolicy.go new file mode 100644 index 000000000..7b1ac160a --- /dev/null +++ b/ste/destReauthPolicy.go @@ -0,0 +1,114 @@ +package ste + +import ( + "context" + "errors" + "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" + "github.com/Azure/azure-storage-azcopy/v10/common" + "net/http" + "sync" + "time" +) + +/* +RESPONSE Status: 401 Server failed to authenticate the request. Please refer to the information in the www-authenticate header. +Www-Authenticate: Bearer authorization_uri=https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/authorize resource_id=https://storage.azure.com +X-Ms-Error-Code: InvalidAuthenticationInfo + InvalidAuthenticationInfoServer failed to authenticate the request. Please refer to the information in the www-authenticate header. Lifetime validation failed. The token is expired. +*/ + +type destReauthPolicy struct { + cred *common.ScopedAuthenticator +} + +var reauthLock *sync.Cond = sync.NewCond(&sync.Mutex{}) + +func NewDestReauthPolicy(cred *common.ScopedAuthenticator) policy.Policy { + return &destReauthPolicy{cred} +} + +func (d *destReauthPolicy) Do(req *policy.Request) (*http.Response, error) { +retry: + ctx := req.Raw().Context() + debugCtx := context.WithValue(ctx, "destAuthDebug", true) + + clone := req.Clone(ctx) + resp, err := clone.Next() // Initially attempt the request. + + if err != nil || resp.StatusCode != http.StatusOK { // But, if we get back an error... + var authReq *azidentity.AuthenticationRequiredError + var respErr = &azcore.ResponseError{} + + reauth := false + + switch { // Is it an error we can resolve by re-authing? + case errors.As(err, &authReq): + reauth = true + debugCtx = context.WithValue(debugCtx, "reauthSrc", "AuthenticationRequiredError") + case resp.StatusCode == http.StatusUnauthorized: + errors.As(runtime.NewResponseError(resp), &respErr) + reauth = err == nil && + bloberror.HasCode(respErr, bloberror.InvalidAuthenticationInfo) && + len(respErr.RawResponse.Header.Values("WWW-Authenticate")) != 0 + if reauth { + debugCtx = context.WithValue(debugCtx, "reauthSrc", "InvalidAuthenticationInfo") + } + } + + if reauth { // If it is, pull the lock if we can, reauth + m := reauthLock.L.(*sync.Mutex) + + if m.TryLock() { // Fetch the lock and try until we get auth. + for { + if ctx.Value("noPrompt") == nil { + _ = common.GetLifecycleMgr().Prompt("Authentication is required to continue the job. Reauthorize and continue?", common.PromptDetails{ + PromptType: common.EPromptType.Reauth(), + ResponseOptions: []common.ResponseOption{ + common.EResponseOption.Yes(), + }, + }) + } + + _, err = d.cred.Authenticate(debugCtx, &policy.TokenRequestOptions{ + Scopes: []string{}, + }) + + // I (Adele Reed) was initially worried about every case + // Thinking about it further, the worst case is that the job ends automatically, or when the user asks it to end. + // To avoid having to handle every error, we'll catch the cancel case as a way to exit the routine, but otherwise + // we will let it happen, and just retry. + if err == nil { + break + } else { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + select { + case <-ctx.Done(): // If it was us, exit like asked. + return nil, err // If it was us, that's legitimately important. + default: // If it was them, we don't care. + } + } else { + common.GetLifecycleMgr().Info(fmt.Sprintf("Authentication failed, awaiting input to continue: %s", err)) + } + + time.Sleep(time.Second * 5) + } + } + + m.Unlock() + reauthLock.Broadcast() + } else { // Otherwise, wait for a signal that we can try again. + reauthLock.Wait() + } + + // Try the request once more + goto retry + } // If it wasn't, we won't retry, and we'll simply return the error. + } + + return resp, err +} diff --git a/ste/mgr-JobPartMgr.go b/ste/mgr-JobPartMgr.go index 62bf9f686..f7a487291 100644 --- a/ste/mgr-JobPartMgr.go +++ b/ste/mgr-JobPartMgr.go @@ -129,12 +129,15 @@ func (d *dialRateLimiter) DialContext(ctx context.Context, network, address stri return d.dialer.DialContext(ctx, network, address) } -func NewClientOptions(retry policy.RetryOptions, telemetry policy.TelemetryOptions, transport policy.Transporter, log LogOptions, srcCred *common.ScopedCredential) azcore.ClientOptions { +func NewClientOptions(retry policy.RetryOptions, telemetry policy.TelemetryOptions, transport policy.Transporter, log LogOptions, srcCred *common.ScopedToken, dstCred *common.ScopedAuthenticator) azcore.ClientOptions { // Pipeline will look like // [includeResponsePolicy, newAPIVersionPolicy (ignored), NewTelemetryPolicy, perCall, NewRetryPolicy, perRetry, NewLogPolicy, httpHeaderPolicy, bodyDownloadPolicy] perCallPolicies := []policy.Policy{azruntime.NewRequestIDPolicy(), NewVersionPolicy(), newFileUploadRangeFromURLFixPolicy()} // TODO : Default logging policy is not equivalent to old one. tracing HTTP request perRetryPolicies := []policy.Policy{newRetryNotificationPolicy(), newLogPolicy(log), newStatsPolicy()} + if dstCred != nil { + perCallPolicies = append(perRetryPolicies, NewDestReauthPolicy(dstCred)) + } if srcCred != nil { perRetryPolicies = append(perRetryPolicies, NewSourceAuthPolicy(srcCred)) } diff --git a/ste/testJobPartTransferManager_test.go b/ste/testJobPartTransferManager_test.go index 55d8fb32d..0334766c4 100644 --- a/ste/testJobPartTransferManager_test.go +++ b/ste/testJobPartTransferManager_test.go @@ -292,7 +292,7 @@ func (t *testJobPartTransferManager) S2SSourceClientOptions() azcore.ClientOptio httpClient := NewAzcopyHTTPClient(4) - return NewClientOptions(retryOptions, telemetryOptions, httpClient, LogOptions{}, nil) + return NewClientOptions(retryOptions, telemetryOptions, httpClient, LogOptions{}, nil, nil) } func (t *testJobPartTransferManager) CredentialOpOptions() *common.CredentialOpOptions { diff --git a/ste/zt_destReauthPolicy_test.go b/ste/zt_destReauthPolicy_test.go new file mode 100644 index 000000000..b744f3aaa --- /dev/null +++ b/ste/zt_destReauthPolicy_test.go @@ -0,0 +1,135 @@ +package ste + +import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + blobservice "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" + "github.com/Azure/azure-storage-azcopy/v10/common" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "strings" + "testing" + "time" +) + +const ( + authRequiredResp = ` + + InvalidAuthenticationInfo + placeholder +` + accountPropsResp = ` + +` +) + +type ReauthTransporter struct { + RequireAuth bool +} + +func (r *ReauthTransporter) Do(req *http.Request) (*http.Response, error) { + if r.RequireAuth { // Format this as a retry blob error + h := http.Header{} + h.Add("WWW-Authenticate", "Bearer authorization_uri=https://login.microsoftonline.com/c1cacfe1-4dd7-4d62-b8c5-5b6cf62d10f9/oauth2/authorize resource_id=https://storage.azure.com") + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Status: http.StatusText(http.StatusUnauthorized), + ContentLength: int64(len(authRequiredResp)), + Body: io.NopCloser(strings.NewReader(authRequiredResp)), + Request: req, + Header: h, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Status: http.StatusText(http.StatusOK), + ContentLength: int64(len(accountPropsResp)), + Body: io.NopCloser(strings.NewReader(accountPropsResp)), + Request: req, + }, nil +} + +type ReauthTestCred struct { + // ImmediateReauth fires off an azidentity.AuthenticationRequiredError in GetToken + ImmediateReauth bool + + ReauthCallback func(ctx context.Context) +} + +func (r *ReauthTestCred) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + if r.ImmediateReauth { + return azcore.AccessToken{}, &azidentity.AuthenticationRequiredError{} + } + + return azcore.AccessToken{Token: "foobar", ExpiresOn: time.Now().Add(time.Hour * 24)}, nil +} + +func (r *ReauthTestCred) Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) { + if r.ReauthCallback != nil { + r.ReauthCallback(ctx) + } + + r.ImmediateReauth = false + + return azidentity.AuthenticationRecord{}, nil +} + +// This is not an end-to-end test. But it is an instantaneous validation of the logic. +func TestDestReauthPolicy(t *testing.T) { + rootctx := context.WithValue(context.Background(), "noPrompt", true) + ctx, cancel := context.WithCancel(rootctx) + + reauthed := false + cred := &ReauthTestCred{ + ReauthCallback: func(ctx context.Context) { + reauthed = true + assert.Equal(t, ctx.Value("destAuthDebug"), true, "Expected reauth to occur via the policy") + assert.Equal(t, ctx.Value("destAuthDebug"), true, "Expected reauth to occur via the AuthenticationRequired mechanism") + cancel() + }, + } + + transport := &ReauthTransporter{} + + opts := NewClientOptions( + policy.RetryOptions{}, + policy.TelemetryOptions{}, + transport, + LogOptions{}, + nil, (*common.ScopedAuthenticator)(common.NewScopedCredential[common.AuthenticateToken](cred, common.ECredentialType.OAuthToken())), + ) + + c, err := blobservice.NewClient("https://foobar.blob.core.windows.net/", cred, &blobservice.ClientOptions{ClientOptions: opts}) + assert.NoError(t, err, "Failed to create service client") + + // Initially, fire off a request that will get slapped with an AuthenticationRequired. + cred.ImmediateReauth = true + _, err = c.GetProperties(ctx, nil) + assert.Equal(t, reauthed, true, "Expected reauthentication attempt in request") + + // Reset the context + ctx, cancel = context.WithCancel(rootctx) + ctx = context.WithValue(ctx, "noPrompt", true) + reauthed = false + + // Set the cred & request to require a reauth on round trip + cred.ImmediateReauth = false + transport.RequireAuth = true + + // Turning RequireAuth is probably not necessary since the context is cancelled but... + cred.ReauthCallback = func(ctx context.Context) { + reauthed = true + assert.Equal(t, ctx.Value("destAuthDebug"), true, "Expected reauth to occur via the policy") + cancel() + transport.RequireAuth = false + } + + // Initially, fire off a request that will get slapped with an AuthenticationRequired. + cred.ImmediateReauth = true + _, err = c.GetProperties(ctx, nil) + assert.Equal(t, reauthed, true, "Expected reauthentication attempt in request") +} From 0c060838b02d329c637fbe6e47aebad5b64564d8 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Fri, 17 Jan 2025 00:04:03 -0800 Subject: [PATCH 02/10] Use constants instead of magic strings --- ste/destReauthPolicy.go | 16 ++++++++++++---- ste/zt_destReauthPolicy_test.go | 20 +++++++++++--------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/ste/destReauthPolicy.go b/ste/destReauthPolicy.go index 7b1ac160a..7d3a50aec 100644 --- a/ste/destReauthPolicy.go +++ b/ste/destReauthPolicy.go @@ -32,10 +32,18 @@ func NewDestReauthPolicy(cred *common.ScopedAuthenticator) policy.Policy { return &destReauthPolicy{cred} } +const ( + destReauthDebugExecuted = "destReauthExec" + destReauthDebugNoPrompt = "destReauthNoPrompt" + destReauthDebugCause = "destReauthCause" + destReauthDebugCauseAuthenticationRequired = "AuthenticationRequiredError" + destReauthDebugCauseInvalidAuthenticationInfo = "InvalidAuthenticationInfoError" +) + func (d *destReauthPolicy) Do(req *policy.Request) (*http.Response, error) { retry: ctx := req.Raw().Context() - debugCtx := context.WithValue(ctx, "destAuthDebug", true) + debugCtx := context.WithValue(ctx, destReauthDebugExecuted, true) clone := req.Clone(ctx) resp, err := clone.Next() // Initially attempt the request. @@ -49,14 +57,14 @@ retry: switch { // Is it an error we can resolve by re-authing? case errors.As(err, &authReq): reauth = true - debugCtx = context.WithValue(debugCtx, "reauthSrc", "AuthenticationRequiredError") + debugCtx = context.WithValue(debugCtx, destReauthDebugCause, destReauthDebugCauseAuthenticationRequired) case resp.StatusCode == http.StatusUnauthorized: errors.As(runtime.NewResponseError(resp), &respErr) reauth = err == nil && bloberror.HasCode(respErr, bloberror.InvalidAuthenticationInfo) && len(respErr.RawResponse.Header.Values("WWW-Authenticate")) != 0 if reauth { - debugCtx = context.WithValue(debugCtx, "reauthSrc", "InvalidAuthenticationInfo") + debugCtx = context.WithValue(debugCtx, destReauthDebugCause, destReauthDebugCauseInvalidAuthenticationInfo) } } @@ -65,7 +73,7 @@ retry: if m.TryLock() { // Fetch the lock and try until we get auth. for { - if ctx.Value("noPrompt") == nil { + if ctx.Value(destReauthDebugNoPrompt) == nil { _ = common.GetLifecycleMgr().Prompt("Authentication is required to continue the job. Reauthorize and continue?", common.PromptDetails{ PromptType: common.EPromptType.Reauth(), ResponseOptions: []common.ResponseOption{ diff --git a/ste/zt_destReauthPolicy_test.go b/ste/zt_destReauthPolicy_test.go index b744f3aaa..c35b31275 100644 --- a/ste/zt_destReauthPolicy_test.go +++ b/ste/zt_destReauthPolicy_test.go @@ -80,15 +80,15 @@ func (r *ReauthTestCred) Authenticate(ctx context.Context, opts *policy.TokenReq // This is not an end-to-end test. But it is an instantaneous validation of the logic. func TestDestReauthPolicy(t *testing.T) { - rootctx := context.WithValue(context.Background(), "noPrompt", true) + rootctx := context.WithValue(context.Background(), destReauthDebugNoPrompt, true) ctx, cancel := context.WithCancel(rootctx) reauthed := false cred := &ReauthTestCred{ ReauthCallback: func(ctx context.Context) { reauthed = true - assert.Equal(t, ctx.Value("destAuthDebug"), true, "Expected reauth to occur via the policy") - assert.Equal(t, ctx.Value("destAuthDebug"), true, "Expected reauth to occur via the AuthenticationRequired mechanism") + assert.Equal(t, ctx.Value(destReauthDebugExecuted), true, "Expected reauth to occur via the policy") + assert.Equal(t, ctx.Value(destReauthDebugCause), destReauthDebugCauseAuthenticationRequired, "Expected reauth to occur via the AuthenticationRequired mechanism") cancel() }, } @@ -113,23 +113,25 @@ func TestDestReauthPolicy(t *testing.T) { // Reset the context ctx, cancel = context.WithCancel(rootctx) - ctx = context.WithValue(ctx, "noPrompt", true) + ctx = context.WithValue(ctx, destReauthDebugNoPrompt, true) reauthed = false - // Set the cred & request to require a reauth on round trip + // =========== InvalidAuthenticationInfo ============ + + // Set the cred & request to require a reauth on round trip, rather than up front, triggering the alternative activation method cred.ImmediateReauth = false transport.RequireAuth = true - // Turning RequireAuth is probably not necessary since the context is cancelled but... + // reset callback cred.ReauthCallback = func(ctx context.Context) { reauthed = true - assert.Equal(t, ctx.Value("destAuthDebug"), true, "Expected reauth to occur via the policy") + assert.Equal(t, ctx.Value(destReauthDebugExecuted), true, "Expected reauth to occur via the policy") + assert.Equal(t, ctx.Value(destReauthDebugCause), destReauthDebugCauseInvalidAuthenticationInfo, "Expected reauth to occur via the InvalidAuthenticationInfo mechanism") cancel() transport.RequireAuth = false } - // Initially, fire off a request that will get slapped with an AuthenticationRequired. - cred.ImmediateReauth = true + // Initially, fire off a request that will get slapped with an InvalidAuthenticationMethod _, err = c.GetProperties(ctx, nil) assert.Equal(t, reauthed, true, "Expected reauthentication attempt in request") } From 8b60e96e50bdcf54ba1e9b250d54f47441e1d1c0 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Fri, 17 Jan 2025 00:11:16 -0800 Subject: [PATCH 03/10] Resolve build errors --- cmd/credentialUtil.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/credentialUtil.go b/cmd/credentialUtil.go index 4e4f5b784..6b09c48b6 100644 --- a/cmd/credentialUtil.go +++ b/cmd/credentialUtil.go @@ -364,7 +364,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk RetryDelay: ste.UploadRetryDelay, MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ - ApplicationID: glcm.AddUserAgentPrefix(common.UserAgent), + ApplicationID: common.AddUserAgentPrefix(common.UserAgent), }, nil, ste.LogOptions{}, nil, nil) blobClient, _ := blob.NewClientWithNoCredential(bURLParts.String(), &blob.ClientOptions{ClientOptions: clientOptions}) @@ -397,7 +397,7 @@ func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions RetryDelay: ste.UploadRetryDelay, MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ - ApplicationID: glcm.AddUserAgentPrefix(common.UserAgent), + ApplicationID: common.AddUserAgentPrefix(common.UserAgent), }, nil, ste.LogOptions{}, nil, nil) blobClient, _ := blob.NewClientWithNoCredential(blobResourceURL, &blob.ClientOptions{ClientOptions: clientOptions}) @@ -591,7 +591,7 @@ func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedT RetryDelay: ste.UploadRetryDelay, MaxRetryDelay: ste.UploadMaxRetryDelay, }, policy.TelemetryOptions{ - ApplicationID: glcm.AddUserAgentPrefix(common.UserAgent), + ApplicationID: common.AddUserAgentPrefix(common.UserAgent), }, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred, reauthCred) } From 83b02c3e3c6c70cee7b41bbe342db0b208794e81 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Fri, 17 Jan 2025 10:08:10 -0800 Subject: [PATCH 04/10] lint --- ste/destReauthPolicy.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ste/destReauthPolicy.go b/ste/destReauthPolicy.go index 7d3a50aec..f29038653 100644 --- a/ste/destReauthPolicy.go +++ b/ste/destReauthPolicy.go @@ -32,12 +32,14 @@ func NewDestReauthPolicy(cred *common.ScopedAuthenticator) policy.Policy { return &destReauthPolicy{cred} } -const ( - destReauthDebugExecuted = "destReauthExec" - destReauthDebugNoPrompt = "destReauthNoPrompt" - destReauthDebugCause = "destReauthCause" - destReauthDebugCauseAuthenticationRequired = "AuthenticationRequiredError" - destReauthDebugCauseInvalidAuthenticationInfo = "InvalidAuthenticationInfoError" +type destReauthDebug string + +var ( + destReauthDebugExecuted destReauthDebug = "executed" + destReauthDebugNoPrompt destReauthDebug = "destReauthNoPrompt" + destReauthDebugCause destReauthDebug = "destReauthCause" + destReauthDebugCauseAuthenticationRequired destReauthDebug = "AuthenticationRequiredError" + destReauthDebugCauseInvalidAuthenticationInfo destReauthDebug = "InvalidAuthenticationInfoError" ) func (d *destReauthPolicy) Do(req *policy.Request) (*http.Response, error) { From 6c142749867e1fd58d37c1dbc245e438c8211d9e Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Fri, 17 Jan 2025 11:40:15 -0800 Subject: [PATCH 05/10] lint --- cmd/syncEnumerator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/syncEnumerator.go b/cmd/syncEnumerator.go index dc244786a..a02bf962e 100644 --- a/cmd/syncEnumerator.go +++ b/cmd/syncEnumerator.go @@ -218,7 +218,7 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s var dstReauthTok *common.ScopedAuthenticator if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. - srcReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) } var srcTokenCred *common.ScopedToken From 83fe460dd50f2fffc5d6a6953af5deb6126b855c Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Tue, 21 Jan 2025 15:51:59 -0800 Subject: [PATCH 06/10] nil deref --- cmd/zc_enumerator.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cmd/zc_enumerator.go b/cmd/zc_enumerator.go index 907e08250..a1fa87d59 100644 --- a/cmd/zc_enumerator.go +++ b/cmd/zc_enumerator.go @@ -382,9 +382,11 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat } var reauthTok *common.ScopedAuthenticator - if at, ok := credential.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { - // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. - reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + if credential != nil { + if at, ok := credential.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { + // This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands. + reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken())) + } } options := createClientOptions(azcopyScanningLogger, nil, reauthTok) From 7e83830393611f5964368078cf015d66b06ad2a7 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Wed, 22 Jan 2025 03:57:26 -0800 Subject: [PATCH 07/10] Backport sharequota test change --- e2etest/zt_newe2e_file_test.go | 33 +++++++++++++-------------------- ste/mgr-JobPartTransferMgr.go | 5 +++-- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/e2etest/zt_newe2e_file_test.go b/e2etest/zt_newe2e_file_test.go index 5e8db457c..77e2fcb9a 100644 --- a/e2etest/zt_newe2e_file_test.go +++ b/e2etest/zt_newe2e_file_test.go @@ -561,28 +561,22 @@ func (s *FileTestSuite) Scenario_UploadFilesWithQuota(svm *ScenarioVariationMana svm.Assert("Quota is 1GB", Equal{Deep: true}, DerefOrZero(shareResource.GetProperties(svm).FileContainerProperties.Quota), int32(1)) - fileNames := []string{"file_1.txt", "file_2.txt"} - - // Create src obj mapping - srcObjs := make(ObjectResourceMappingFlat) - - // Create source files - srcContainer := CreateResource[ContainerResourceManager](svm, GetRootResource(svm, common.ELocation.Local()), - ResourceDefinitionContainer{Objects: srcObjs}) - for _, fileName := range fileNames { - body := NewRandomObjectContentContainer(int64(1) * common.GigaByte) - obj := ResourceDefinitionObject{ - ObjectName: &fileName, - Body: body, - Size: "1.00 GiB", - } - srcObjs[fileName] = obj - CreateResource[ObjectResourceManager](svm, srcContainer, obj) + // Fill the share up + if !svm.Dryrun() { + shareClient := shareResource.(*FileShareResourceManager).internalClient + fileClient := shareClient.NewRootDirectoryClient().NewFileClient("big.txt") + _, err := fileClient.Create(ctx, 990*common.MegaByte, nil) + svm.NoError("Create large file", err) } + srcOverflowObject := CreateResource[ObjectResourceManager](svm, GetRootResource(svm, common.ELocation.Local()), + ResourceDefinitionObject{ + Body: NewRandomObjectContentContainer(common.GigaByte), + }) + stdOut, _ := RunAzCopy(svm, AzCopyCommand{ Verb: AzCopyVerbCopy, - Targets: []ResourceManager{srcContainer, shareResource}, + Targets: []ResourceManager{srcOverflowObject, shareResource}, Flags: CopyFlags{ CopySyncCommonFlags: CopySyncCommonFlags{ Recursive: pointerTo(true), @@ -595,7 +589,7 @@ func (s *FileTestSuite) Scenario_UploadFilesWithQuota(svm *ScenarioVariationMana ValidateContainsError(svm, stdOut, []string{"Increase the file share quota and call Resume command."}) fileMap := shareResource.ListObjects(svm, "", true) - svm.Assert("One file should be uploaded within the quota", Equal{}, len(fileMap)-1, 1) // -1 to Account for root dir in fileMap + svm.Assert("One file should be uploaded within the quota", Equal{}, len(fileMap), 1) // -1 to Account for root dir in fileMap // Increase quota to fit all files newQuota := int32(2) @@ -609,5 +603,4 @@ func (s *FileTestSuite) Scenario_UploadFilesWithQuota(svm *ScenarioVariationMana svm.Assert("Quota should be updated", Equal{}, DerefOrZero(shareResource.GetProperties(svm).FileContainerProperties.Quota), newQuota) - } diff --git a/ste/mgr-JobPartTransferMgr.go b/ste/mgr-JobPartTransferMgr.go index dfe91adf1..0a3020f0a 100644 --- a/ste/mgr-JobPartTransferMgr.go +++ b/ste/mgr-JobPartTransferMgr.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/fileerror" "net/http" "strings" "sync/atomic" @@ -864,8 +865,8 @@ func (jptm *jobPartTransferMgr) failActiveTransfer(typ transferErrorCode, descri common.GetLifecycleMgr().Info(fmt.Sprintf("Authentication failed, it is either not correct, or expired, or does not have the correct permission %s", err.Error())) } - if serviceCode == "ShareSizeLimitReached" { - common.GetLifecycleMgr().Error("Increase the file share quota and call Resume command.") + if fileerror.HasCode(err, "ShareSizeLimitReached") { + common.GetLifecycleMgr().Info("Increase the file share quota and call Resume command.") } // and use the normal cancelling mechanism so that we can exit in a clean and controlled way From dd49ce54ca1c702ac6814012e3fcbec49381bc12 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Tue, 21 Jan 2025 13:55:43 -0800 Subject: [PATCH 08/10] codespell --- cmd/copyEnumeratorInit.go | 2 +- cmd/zt_parseSize_test.go | 4 ++-- e2etest/newe2e_account_registry.go | 2 +- e2etest/newe2e_task_runazcopy.go | 2 +- ste/md5Comparer.go | 2 +- ste/pageRangeOptimizer.go | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cmd/copyEnumeratorInit.go b/cmd/copyEnumeratorInit.go index 0e53476ee..72cfb392e 100755 --- a/cmd/copyEnumeratorInit.go +++ b/cmd/copyEnumeratorInit.go @@ -314,7 +314,7 @@ func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrder common.CopyJobPartOrde return NewCopyEnumerator(traverser, filters, processor, finalizer), nil } -// This is condensed down into an individual function as we don't end up re-using the destination traverser at all. +// This is condensed down into an individual function as we don't end up reusing the destination traverser at all. // This is just for the directory check. func (cca *CookedCopyCmdArgs) isDestDirectory(dst common.ResourceString, ctx *context.Context) bool { var err error diff --git a/cmd/zt_parseSize_test.go b/cmd/zt_parseSize_test.go index c0f0b7ca6..5895db40b 100644 --- a/cmd/zt_parseSize_test.go +++ b/cmd/zt_parseSize_test.go @@ -55,7 +55,7 @@ func TestParseSize(t *testing.T) { _, err = ParseSizeString("123T", "foo-bar") // we don't support terabytes a.Equal(expectedError, err.Error()) - _, err = ParseSizeString("abcK", "foo-bar") + _, err = ParseSizeString("abcK", "foo-bar") //codespell:ignore abcK a.Equal(expectedError, err.Error()) -} \ No newline at end of file +} diff --git a/e2etest/newe2e_account_registry.go b/e2etest/newe2e_account_registry.go index 1ba4a8f9d..653eb0a8d 100644 --- a/e2etest/newe2e_account_registry.go +++ b/e2etest/newe2e_account_registry.go @@ -10,7 +10,7 @@ import ( // AccountRegistry is a set of accounts that are intended to be initialized when the tests start running. // Suites and tests should not add to this pool. // todo: long-term, support flexible static configuration of accounts. -var AccountRegistry = map[string]AccountResourceManager{} // For re-using accounts across testing +var AccountRegistry = map[string]AccountResourceManager{} // For reusing accounts across testing func GetAccount(a Asserter, AccountName string) AccountResourceManager { targetAccount, ok := AccountRegistry[AccountName] diff --git a/e2etest/newe2e_task_runazcopy.go b/e2etest/newe2e_task_runazcopy.go index f20b21fb4..04c1a4da1 100644 --- a/e2etest/newe2e_task_runazcopy.go +++ b/e2etest/newe2e_task_runazcopy.go @@ -112,7 +112,7 @@ type AzCopyCommand struct { } type AzCopyEnvironment struct { - // `env:"XYZ"` is re-used but does not inherit the traits of config's env trait. Merely used for low-code mapping. + // `env:"XYZ"` is reused but does not inherit the traits of config's env trait. Merely used for low-code mapping. LogLocation *string `env:"AZCOPY_LOG_LOCATION,defaultfunc:DefaultLogLoc"` JobPlanLocation *string `env:"AZCOPY_JOB_PLAN_LOCATION,defaultfunc:DefaultPlanLoc"` diff --git a/ste/md5Comparer.go b/ste/md5Comparer.go index 5929ab297..87d5af04c 100644 --- a/ste/md5Comparer.go +++ b/ste/md5Comparer.go @@ -64,7 +64,7 @@ func (c *md5Comparer) Check() error { switch c.validationOption { // This code would never be triggered anymore due to the early check that now occurs in xfer-remoteToLocal.go case common.EHashValidationOption.FailIfDifferentOrMissing(): - panic("Transfer should've pre-emptively failed with a missing MD5.") + panic("Transfer should've preemptively failed with a missing MD5.") case common.EHashValidationOption.FailIfDifferent(), common.EHashValidationOption.LogOnly(): c.logAsMissing() diff --git a/ste/pageRangeOptimizer.go b/ste/pageRangeOptimizer.go index a8cc96a0a..0316f0eee 100644 --- a/ste/pageRangeOptimizer.go +++ b/ste/pageRangeOptimizer.go @@ -31,7 +31,7 @@ import ( // isolate the logic to fetch page ranges for a page blob, and check whether a given range has data // for two purposes: // 1. capture the necessary info to do so, so that fetchPages can be invoked anywhere -// 2. open to extending the logic, which could be re-used for both download and s2s scenarios +// 2. open to extending the logic, which could be reused for both download and s2s scenarios type pageRangeOptimizer struct { srcPageBlobClient *pageblob.Client ctx context.Context From 2728563fd69ea99ad3c56c486f3d518957441795 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Tue, 21 Jan 2025 14:02:54 -0800 Subject: [PATCH 09/10] codespell? --- cmd/zt_parseSize_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/zt_parseSize_test.go b/cmd/zt_parseSize_test.go index 5895db40b..dee421da9 100644 --- a/cmd/zt_parseSize_test.go +++ b/cmd/zt_parseSize_test.go @@ -55,7 +55,7 @@ func TestParseSize(t *testing.T) { _, err = ParseSizeString("123T", "foo-bar") // we don't support terabytes a.Equal(expectedError, err.Error()) - _, err = ParseSizeString("abcK", "foo-bar") //codespell:ignore abcK + _, err = ParseSizeString("abcK", "foo-bar") //codespell:ignore a.Equal(expectedError, err.Error()) } From 4db85aa9a524341301122a4f87a2e8caca9accb9 Mon Sep 17 00:00:00 2001 From: Adele Reed Date: Wed, 22 Jan 2025 08:28:00 -0800 Subject: [PATCH 10/10] Return to Error --- ste/mgr-JobPartTransferMgr.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ste/mgr-JobPartTransferMgr.go b/ste/mgr-JobPartTransferMgr.go index 0a3020f0a..e1d81b878 100644 --- a/ste/mgr-JobPartTransferMgr.go +++ b/ste/mgr-JobPartTransferMgr.go @@ -866,7 +866,7 @@ func (jptm *jobPartTransferMgr) failActiveTransfer(typ transferErrorCode, descri } if fileerror.HasCode(err, "ShareSizeLimitReached") { - common.GetLifecycleMgr().Info("Increase the file share quota and call Resume command.") + common.GetLifecycleMgr().Error("Increase the file share quota and call Resume command.") } // and use the normal cancelling mechanism so that we can exit in a clean and controlled way