Skip to content

Commit f292ac0

Browse files
authored
Fix the TokenStore getting stuck in a read lock (#3035)
1 parent 73bf86e commit f292ac0

File tree

2 files changed

+96
-8
lines changed

2 files changed

+96
-8
lines changed

common/oauthTokenManager.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,17 @@ const DefaultActiveDirectoryEndpoint = "https://login.microsoftonline.com"
6363

6464
const TokenCache = "AzCopyTokenCache"
6565

66+
type CredCacheImplementation interface {
67+
HasCachedToken() (bool, error)
68+
LoadToken() (*OAuthTokenInfo, error)
69+
SaveToken(OAuthTokenInfo) error
70+
RemoveCachedToken() error
71+
}
72+
6673
// UserOAuthTokenManager for token management.
6774
type UserOAuthTokenManager struct {
6875
oauthClient *http.Client
69-
credCache *CredCache
76+
credCache CredCacheImplementation
7077

7178
// Stash the credential info as we delete the environment variable after reading it, and we need to get it multiple times.
7279
stashedInfo *OAuthTokenInfo
@@ -479,7 +486,7 @@ func (credInfo *OAuthTokenInfo) Refresh(ctx context.Context) (*Token, error) {
479486
}
480487

481488
// Single instance token store credential cache shared by entire azcopy process.
482-
var tokenStoreCredCache = NewCredCacheInternalIntegration(CredCacheOptions{
489+
var tokenStoreCredCache CredCacheImplementation = NewCredCacheInternalIntegration(CredCacheOptions{
483490
KeyName: "azcopy/aadtoken/" + strconv.Itoa(os.Getpid()),
484491
ServiceName: "azcopy",
485492
AccountName: "aadtoken/" + strconv.Itoa(os.Getpid()),
@@ -510,8 +517,9 @@ func getAuthorityURL(activeDirectoryEndpoint string) (*url.URL, error) {
510517
const minimumTokenValidDuration = time.Minute * 5
511518

512519
type TokenStoreCredential struct {
513-
token *azcore.AccessToken
514-
lock sync.RWMutex
520+
token *azcore.AccessToken
521+
lock sync.RWMutex
522+
credCache CredCacheImplementation
515523
}
516524

517525
// globalTokenStoreCredential is created to make sure that all
@@ -531,22 +539,26 @@ var globalTokenStoreCredential *TokenStoreCredential
531539
var globalTsc sync.Once
532540

533541
func (tsc *TokenStoreCredential) GetToken(_ context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
534-
// if the token we've has not expired, return the same.
542+
// if the token we have has not expired, return the same.
535543
tsc.lock.RLock()
536-
if time.Until(tsc.token.ExpiresOn) > minimumTokenValidDuration {
544+
if rem := time.Until(tsc.token.ExpiresOn); rem > minimumTokenValidDuration {
545+
tsc.lock.RUnlock() // return path, so we must release the read lock here as well.
537546
return *tsc.token, nil
538547
}
539548
tsc.lock.RUnlock()
540549

541550
tsc.lock.Lock()
542551
defer tsc.lock.Unlock()
543-
hasToken, err := tokenStoreCredCache.HasCachedToken()
552+
553+
hasToken, err := tsc.credCache.HasCachedToken()
544554
if err != nil || !hasToken {
555+
AzcopyCurrentJobLogger.Log(LogDebug, fmt.Sprintf("no token found %v", err))
545556
return azcore.AccessToken{}, fmt.Errorf("no cached token found in Token Store Mode(SE), %w", err)
546557
}
547558

548-
tokenInfo, err := tokenStoreCredCache.LoadToken()
559+
tokenInfo, err := tsc.credCache.LoadToken()
549560
if err != nil {
561+
AzcopyCurrentJobLogger.Log(LogDebug, fmt.Sprintf("get token failed %s", err.Error()))
550562
return azcore.AccessToken{}, fmt.Errorf("get cached token failed in Token Store Mode(SE), %w", err)
551563
}
552564

@@ -568,6 +580,7 @@ func GetTokenStoreCredential(accessToken string, expiresOn time.Time) azcore.Tok
568580
Token: accessToken,
569581
ExpiresOn: expiresOn,
570582
},
583+
credCache: tokenStoreCredCache,
571584
}
572585
})
573586
return globalTokenStoreCredential

common/oauthTokenManager_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,17 @@ package common
2222

2323
import (
2424
"context"
25+
"encoding/json"
26+
"errors"
2527
"fmt"
28+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
29+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
30+
"github.com/stretchr/testify/assert"
2631
"os"
2732
"reflect"
2833
"strconv"
2934
"testing"
35+
"time"
3036
)
3137

3238
const tokenInfoJson = `{
@@ -183,3 +189,72 @@ func TestUserOAuthTokenManager_GetTokenInfo(t *testing.T) {
183189
})
184190
}
185191
}
192+
193+
type TestCredCache struct {
194+
stashed *OAuthTokenInfo
195+
}
196+
197+
func (d *TestCredCache) HasCachedToken() (bool, error) {
198+
return d.stashed != nil, nil
199+
}
200+
201+
func (d *TestCredCache) LoadToken() (*OAuthTokenInfo, error) {
202+
if d.stashed == nil {
203+
return nil, errors.New("no cached token found")
204+
}
205+
206+
return d.stashed, nil
207+
}
208+
209+
func (d *TestCredCache) SaveToken(info OAuthTokenInfo) error {
210+
d.stashed = &info
211+
return nil
212+
}
213+
214+
func (d *TestCredCache) RemoveCachedToken() error {
215+
if d.stashed == nil {
216+
return errors.New("no cached token found")
217+
}
218+
219+
d.stashed = nil
220+
return nil
221+
}
222+
223+
func TestTokenStoreCredentialHang(t *testing.T) {
224+
tok := &azcore.AccessToken{
225+
Token: "asdf",
226+
ExpiresOn: time.Now().Add(minimumTokenValidDuration * 2), // we want to hit that if statement at the start and get it into the read lock
227+
}
228+
229+
tsc := &TokenStoreCredential{
230+
token: tok,
231+
credCache: &TestCredCache{
232+
stashed: &OAuthTokenInfo{
233+
Token: Token{
234+
AccessToken: "foobar",
235+
ExpiresOn: json.Number(fmt.Sprint(tok.ExpiresOn.Unix())),
236+
},
237+
},
238+
},
239+
}
240+
241+
// Prior to this PR, we'd get locked into a read state doing this, because the if statement didn't contain a way out of the read lock.
242+
outTok, err := tsc.GetToken(context.Background(), policy.TokenRequestOptions{})
243+
assert.NoError(t, err)
244+
assert.Equal(t, tok.Token, outTok.Token)
245+
246+
// we shouldn't get blocked here, otherwise we have problems.
247+
assert.Equal(t, true, tsc.lock.TryLock())
248+
tsc.lock.Unlock()
249+
250+
tok.ExpiresOn = time.Now() // now it should refresh
251+
252+
// We shouldn't get caught here at all
253+
outTok, err = tsc.GetToken(context.Background(), policy.TokenRequestOptions{})
254+
assert.NoError(t, err)
255+
assert.Equal(t, "foobar", outTok.Token) // ensure it refreshed
256+
257+
// we shouldn't get blocked here, otherwise we have problems.
258+
assert.Equal(t, true, tsc.lock.TryLock())
259+
tsc.lock.Unlock()
260+
}

0 commit comments

Comments
 (0)