Skip to content

Commit e11ecec

Browse files
Implement a policy to reauthorize the credential when invalid (#2887)
* Implement a policy to reauthorize the credential when invalid * Use constants instead of magic strings * Resolve build errors * lint * lint * nil deref * Backport sharequota test change * codespell * codespell? * Return to Error --------- Co-authored-by: Gauri Lamunion <[email protected]>
1 parent 30eb5f6 commit e11ecec

25 files changed

+421
-62
lines changed

cmd/copy.go

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,8 @@ func (cca *CookedCopyCmdArgs) processRedirectionDownload(blobResource common.Res
13101310
}
13111311

13121312
// step 1: create client options
1313-
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil)}
1313+
// note: dstCred is nil, as we could not reauth effectively because stdout is a pipe.
1314+
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(azcopyScanningLogger, nil, nil)}
13141315

13151316
// step 2: parse source url
13161317
u, err := blobResource.FullURL()
@@ -1386,8 +1387,15 @@ func (cca *CookedCopyCmdArgs) processRedirectionUpload(blobResource common.Resou
13861387
return fmt.Errorf("fatal: cannot find auth on destination blob URL: %s", err.Error())
13871388
}
13881389

1390+
var reauthTok *common.ScopedAuthenticator
1391+
if at, ok := credInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
1392+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
1393+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
1394+
}
1395+
13891396
// step 0: initialize pipeline
1390-
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil)}
1397+
// Reauthentication is theoretically possible here, since stdin is blocked.
1398+
options := &blockblob.ClientOptions{ClientOptions: createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)}
13911399

13921400
// step 1: parse destination url
13931401
u, err := blobResource.FullURL()
@@ -1569,18 +1577,25 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
15691577
},
15701578
}
15711579

1572-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
1580+
srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder)
1581+
if err != nil {
1582+
return err
1583+
}
1584+
1585+
var srcReauth *common.ScopedAuthenticator
1586+
if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
1587+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
1588+
srcReauth = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
1589+
}
1590+
1591+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauth)
15731592
var azureFileSpecificOptions any
15741593
if cca.FromTo.From() == common.ELocation.File() {
15751594
azureFileSpecificOptions = &common.FileClientOptions{
15761595
AllowTrailingDot: cca.trailingDot.IsEnabled(),
15771596
}
15781597
}
15791598

1580-
srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder)
1581-
if err != nil {
1582-
return err
1583-
}
15841599
jobPartOrder.SrcServiceClient, err = common.GetServiceClientForLocation(
15851600
cca.FromTo.From(),
15861601
cca.Source,
@@ -1600,11 +1615,17 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
16001615
}
16011616
}
16021617

1603-
var srcCred *common.ScopedCredential
1618+
var dstReauthTok *common.ScopedAuthenticator
1619+
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
1620+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
1621+
dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
1622+
}
1623+
1624+
var srcCred *common.ScopedToken
16041625
if cca.FromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() {
16051626
srcCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType)
16061627
}
1607-
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred)
1628+
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, dstReauthTok)
16081629
jobPartOrder.DstServiceClient, err = common.GetServiceClientForLocation(
16091630
cca.FromTo.To(),
16101631
cca.Destination,

cmd/copyEnumeratorInit.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrder common.CopyJobPartOrde
314314
return NewCopyEnumerator(traverser, filters, processor, finalizer), nil
315315
}
316316

317-
// This is condensed down into an individual function as we don't end up re-using the destination traverser at all.
317+
// This is condensed down into an individual function as we don't end up reusing the destination traverser at all.
318318
// This is just for the directory check.
319319
func (cca *CookedCopyCmdArgs) isDestDirectory(dst common.ResourceString, ctx *context.Context) bool {
320320
var err error
@@ -428,7 +428,13 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA
428428
return err
429429
}
430430

431-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
431+
var reauthTok *common.ScopedAuthenticator
432+
if at, ok := dstCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
433+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
434+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
435+
}
436+
437+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
432438

433439
sc, err := common.GetServiceClientForLocation(
434440
cca.FromTo.To(),

cmd/credentialUtil.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk
365365
MaxRetryDelay: ste.UploadMaxRetryDelay,
366366
}, policy.TelemetryOptions{
367367
ApplicationID: common.AddUserAgentPrefix(common.UserAgent),
368-
}, nil, ste.LogOptions{}, nil)
368+
}, nil, ste.LogOptions{}, nil, nil)
369369

370370
blobClient, _ := blob.NewClientWithNoCredential(bURLParts.String(), &blob.ClientOptions{ClientOptions: clientOptions})
371371
bURLParts.BlobName = ""
@@ -398,7 +398,7 @@ func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions
398398
MaxRetryDelay: ste.UploadMaxRetryDelay,
399399
}, policy.TelemetryOptions{
400400
ApplicationID: common.AddUserAgentPrefix(common.UserAgent),
401-
}, nil, ste.LogOptions{}, nil)
401+
}, nil, ste.LogOptions{}, nil, nil)
402402

403403
blobClient, _ := blob.NewClientWithNoCredential(blobResourceURL, &blob.ClientOptions{ClientOptions: clientOptions})
404404
_, err := blobClient.GetProperties(ctx, &blob.GetPropertiesOptions{CPKInfo: cpkOptions.GetCPKInfo()})
@@ -577,7 +577,7 @@ func getCredentialType(ctx context.Context, raw rawFromToInfo, cpkOptions common
577577
// createClientOptions creates generic client options which are required to create any
578578
// client to interact with storage service. Default options are modified to suit azcopy.
579579
// srcCred is required in cases where source is authenticated via oAuth for S2S transfers
580-
func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedCredential) azcore.ClientOptions {
580+
func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedToken, reauthCred *common.ScopedAuthenticator) azcore.ClientOptions {
581581
logOptions := ste.LogOptions{}
582582

583583
if logger != nil {
@@ -592,7 +592,7 @@ func createClientOptions(logger common.ILoggerResetable, srcCred *common.ScopedC
592592
MaxRetryDelay: ste.UploadMaxRetryDelay,
593593
}, policy.TelemetryOptions{
594594
ApplicationID: common.AddUserAgentPrefix(common.UserAgent),
595-
}, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred)
595+
}, ste.NewAzcopyHTTPClient(frontEndMaxIdleConnectionsPerHost), logOptions, srcCred, reauthCred)
596596
}
597597

598598
const frontEndMaxIdleConnectionsPerHost = http.DefaultMaxIdleConnsPerHost

cmd/jobsResume.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,19 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
299299
}
300300
}
301301

302-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
302+
var reauthTok *common.ScopedAuthenticator
303+
if at, ok := tc.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
304+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
305+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
306+
}
303307
jobID, err := common.ParseJobID(rca.jobID)
304308
if err != nil {
305309
// Error for invalid JobId format
306310
return nil, nil, fmt.Errorf("error parsing the jobId %s. Failed with error %s", rca.jobID, err.Error())
307311
}
308312

313+
// But we don't want to supply a reauth token if we're not using OAuth. That could cause problems if say, a SAS is invalid.
314+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, common.Iff(srcCredType.IsAzureOAuth(), reauthTok, nil))
309315
var getJobDetailsResponse common.GetJobDetailsResponse
310316
// Get job details from the STE
311317
Rpc(common.ERpcCmd.GetJobDetails(),
@@ -326,11 +332,11 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
326332
return nil, nil, err
327333
}
328334

329-
var srcCred *common.ScopedCredential
335+
var srcCred *common.ScopedToken
330336
if fromTo.IsS2S() && srcCredType.IsAzureOAuth() {
331337
srcCred = common.NewScopedCredential(tc, srcCredType)
332338
}
333-
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred)
339+
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred, common.Iff(dstCredType.IsAzureOAuth(), reauthTok, nil))
334340
var fileClientOptions any
335341
if fromTo.To() == common.ELocation.File() {
336342
fileClientOptions = &common.FileClientOptions{

cmd/make.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,14 @@ func (cookedArgs cookedMakeCmdArgs) process() (err error) {
9090
return err
9191
}
9292

93+
var reauthTok *common.ScopedAuthenticator
94+
if at, ok := credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
95+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
96+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
97+
}
98+
9399
// Note : trailing dot is only applicable to file operations anyway, so setting this to false
94-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
100+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
95101
resourceURL := cookedArgs.resourceURL.String()
96102
cred := credentialInfo.OAuthTokenInfo.TokenCredential
97103

cmd/removeEnumerator.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,14 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er
9090
if !from.SupportsTrailingDot() {
9191
cca.trailingDot = common.ETrailingDotOption.Disable()
9292
}
93-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
93+
94+
var reauthTok *common.ScopedAuthenticator
95+
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
96+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
97+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
98+
}
99+
100+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
94101
var fileClientOptions any
95102
if cca.FromTo.From() == common.ELocation.File() {
96103
fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot.IsEnabled()}
@@ -142,7 +149,13 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er
142149
func removeBfsResources(cca *CookedCopyCmdArgs) (err error) {
143150
ctx := context.WithValue(context.Background(), ste.ServiceAPIVersionOverride, ste.DefaultServiceApiVersion)
144151
sourceURL, _ := cca.Source.String()
145-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
152+
var reauthTok *common.ScopedAuthenticator
153+
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
154+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
155+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
156+
}
157+
158+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
146159

147160
targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil)
148161
if err != nil {

cmd/root.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ func beginDetectNewVersion() chan struct{} {
311311
PrintOlderVersion(*cachedVersion, *localVersion)
312312
} else {
313313
// step 2: initialize pipeline
314-
options := createClientOptions(nil, nil)
314+
options := createClientOptions(nil, nil, nil)
315315

316316
// step 3: start download
317317
blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options})

cmd/setPropertiesEnumerator.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,13 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator
7272
jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo)
7373
}
7474

75-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
75+
var reauthTok *common.ScopedAuthenticator
76+
if at, ok := cca.credentialInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok { // We don't need two different tokens here since it gets passed in just the same either way.
77+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
78+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
79+
}
80+
81+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, reauthTok)
7682
var fileClientOptions any
7783
if cca.FromTo.From() == common.ELocation.File() {
7884
fileClientOptions = &common.FileClientOptions{AllowTrailingDot: cca.trailingDot.IsEnabled()}

cmd/syncEnumerator.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,13 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s
179179
},
180180
}
181181

182-
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
182+
var srcReauthTok *common.ScopedAuthenticator
183+
if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
184+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
185+
srcReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
186+
}
187+
188+
options := createClientOptions(common.AzcopyCurrentJobLogger, nil, srcReauthTok)
183189

184190
// Create Source Client.
185191
var azureFileSpecificOptions any
@@ -209,12 +215,18 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s
209215
}
210216
}
211217

212-
var srcTokenCred *common.ScopedCredential
218+
var dstReauthTok *common.ScopedAuthenticator
219+
if at, ok := srcCredInfo.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
220+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
221+
dstReauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
222+
}
223+
224+
var srcTokenCred *common.ScopedToken
213225
if cca.fromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() {
214226
srcTokenCred = common.NewScopedCredential(srcCredInfo.OAuthTokenInfo.TokenCredential, srcCredInfo.CredentialType)
215227
}
216228

217-
options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred)
229+
options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred, dstReauthTok)
218230
copyJobTemplate.DstServiceClient, err = common.GetServiceClientForLocation(
219231
cca.fromTo.To(),
220232
cca.destination,

cmd/versionChecker_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ func TestCheckReleaseMetadata(t *testing.T) {
215215
a := assert.New(t)
216216

217217
// sanity test for checking if the release metadata exists and can be downloaded
218-
options := createClientOptions(nil, nil)
218+
options := createClientOptions(nil, nil, nil)
219219

220220
blobClient, err := blob.NewClientWithNoCredential(versionMetadataUrl, &blob.ClientOptions{ClientOptions: options})
221221
a.NoError(err)

cmd/zc_enumerator.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,15 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat
381381
return output, nil
382382
}
383383

384-
options := createClientOptions(azcopyScanningLogger, nil)
384+
var reauthTok *common.ScopedAuthenticator
385+
if credential != nil {
386+
if at, ok := credential.OAuthTokenInfo.TokenCredential.(common.AuthenticateToken); ok {
387+
// This will cause a reauth with StorageScope, which is fine, that's the original Authenticate call as it stands.
388+
reauthTok = (*common.ScopedAuthenticator)(common.NewScopedCredential(at, common.ECredentialType.OAuthToken()))
389+
}
390+
}
391+
392+
options := createClientOptions(azcopyScanningLogger, nil, reauthTok)
385393

386394
switch location {
387395
case common.ELocation.Local():

cmd/zt_parseSize_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func TestParseSize(t *testing.T) {
5555
_, err = ParseSizeString("123T", "foo-bar") // we don't support terabytes
5656
a.Equal(expectedError, err.Error())
5757

58-
_, err = ParseSizeString("abcK", "foo-bar")
58+
_, err = ParseSizeString("abcK", "foo-bar") //codespell:ignore
5959
a.Equal(expectedError, err.Error())
6060

61-
}
61+
}

common/oauthTokenManager.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia
701701
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: authorityHost.String()},
702702
Transport: newAzcopyHTTPClient(),
703703
},
704+
UserPrompt: func(ctx context.Context, message azidentity.DeviceCodeMessage) error {
705+
lcm.Info(fmt.Sprintf("Authentication is required. To sign in, open the webpage %s and enter the code %s to authenticate.",
706+
Iff(message.VerificationURL != "", message.VerificationURL, "https://aka.ms/devicelogin"), message.UserCode))
707+
return nil
708+
},
704709
})
705710
if err != nil {
706711
return nil, err
@@ -727,6 +732,11 @@ func (credInfo *OAuthTokenInfo) GetDeviceCodeCredential() (azcore.TokenCredentia
727732
return tc, nil
728733
}
729734

735+
type AuthenticateToken interface {
736+
azcore.TokenCredential
737+
Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error)
738+
}
739+
730740
func (credInfo *OAuthTokenInfo) GetTokenCredential() (azcore.TokenCredential, error) {
731741
// Token Credential is cached.
732742
if credInfo.TokenCredential != nil {

common/output.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ var EPromptType = PromptType("")
6868

6969
type PromptType string
7070

71+
func (PromptType) Reauth() PromptType { return PromptType("Reauth") }
7172
func (PromptType) Cancel() PromptType { return PromptType("Cancel") }
7273
func (PromptType) Overwrite() PromptType { return PromptType("Overwrite") }
7374
func (PromptType) DeleteDestination() PromptType { return PromptType("DeleteDestination") }

common/util.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
78
"net"
89
"net/url"
910
"strings"
@@ -231,7 +232,7 @@ func GetServiceClientForLocation(loc Location,
231232
// NewScopedCredential takes in a credInfo object and returns ScopedCredential
232233
// if credentialType is either MDOAuth or oAuth. For anything else,
233234
// nil is returned
234-
func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) *ScopedCredential {
235+
func NewScopedCredential[T azcore.TokenCredential](cred T, credType CredentialType) *ScopedCredential[T] {
235236
var scope string
236237
if !credType.IsAzureOAuth() {
237238
return nil
@@ -240,18 +241,29 @@ func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) *
240241
} else if credType == ECredentialType.OAuthToken() {
241242
scope = StorageScope
242243
}
243-
return &ScopedCredential{cred: cred, scopes: []string{scope}}
244+
return &ScopedCredential[T]{cred: cred, scopes: []string{scope}}
244245
}
245246

246-
type ScopedCredential struct {
247-
cred azcore.TokenCredential
247+
type ScopedCredential[T azcore.TokenCredential] struct {
248+
cred T
248249
scopes []string
249250
}
250251

251-
func (s *ScopedCredential) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
252+
func (s *ScopedCredential[T]) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
252253
return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true})
253254
}
254255

256+
type ScopedToken = ScopedCredential[azcore.TokenCredential]
257+
type ScopedAuthenticator ScopedCredential[AuthenticateToken]
258+
259+
func (s *ScopedAuthenticator) GetToken(ctx context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
260+
return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true})
261+
}
262+
263+
func (s *ScopedAuthenticator) Authenticate(ctx context.Context, _ *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) {
264+
return s.cred.Authenticate(ctx, &policy.TokenRequestOptions{Scopes: s.scopes, EnableCAE: true})
265+
}
266+
255267
type ServiceClient struct {
256268
fsc *fileservice.Client
257269
bsc *blobservice.Client

0 commit comments

Comments
 (0)