@@ -26,6 +26,7 @@ import (
2626 "context"
2727 "errors"
2828 "fmt"
29+ "strings"
2930 "sync"
3031 "time"
3132
@@ -38,18 +39,20 @@ import (
3839)
3940
4041type Provider struct {
41- opts ProviderOptions
42- numTokens prometheus.Gauge
43- cacheMisses prometheus.Counter
44- serviceAccounts map [serviceaccounts.Reference ]* serviceAccount
45- googleIDTokens map [googleIDTokenReference ]* tokenAndExpiration [string ]
46- nodeServiceAccountRef * serviceaccounts.Reference
47- ctx context.Context
48- cancelCtx context.CancelFunc
49- serviceAccountsMutex sync.Mutex
50- googleIDTokensMutex sync.RWMutex
51- wg sync.WaitGroup
52- semaphore chan struct {}
42+ opts ProviderOptions
43+ numTokens prometheus.Gauge
44+ cacheMisses prometheus.Counter
45+ serviceAccounts map [serviceaccounts.Reference ]* serviceAccount
46+ googleIDTokens map [googleIDTokenReference ]* tokenAndExpiration [string ]
47+ googleScopedAccessTokens map [googleScopedAccessTokenReference ]* tokenAndExpiration [string ]
48+ nodeServiceAccountRef * serviceaccounts.Reference
49+ ctx context.Context
50+ cancelCtx context.CancelFunc
51+ serviceAccountsMutex sync.Mutex
52+ googleIDTokensMutex sync.RWMutex
53+ googleScopedAccessTokensMutex sync.RWMutex
54+ wg sync.WaitGroup
55+ semaphore chan struct {}
5356}
5457
5558type ProviderOptions struct {
@@ -80,17 +83,18 @@ func NewProvider(ctx context.Context, opts ProviderOptions) *Provider {
8083 backgroundCtx , cancel := context .WithCancel (backgroundCtx )
8184
8285 p := & Provider {
83- opts : opts ,
84- numTokens : numTokens ,
85- cacheMisses : cacheMisses ,
86- serviceAccounts : make (map [serviceaccounts.Reference ]* serviceAccount ),
87- googleIDTokens : make (map [googleIDTokenReference ]* tokenAndExpiration [string ]),
88- ctx : backgroundCtx ,
89- cancelCtx : cancel ,
90- semaphore : make (chan struct {}, opts .Concurrency ),
86+ opts : opts ,
87+ numTokens : numTokens ,
88+ cacheMisses : cacheMisses ,
89+ serviceAccounts : make (map [serviceaccounts.Reference ]* serviceAccount ),
90+ googleIDTokens : make (map [googleIDTokenReference ]* tokenAndExpiration [string ]),
91+ googleScopedAccessTokens : make (map [googleScopedAccessTokenReference ]* tokenAndExpiration [string ]),
92+ ctx : backgroundCtx ,
93+ cancelCtx : cancel ,
94+ semaphore : make (chan struct {}, opts .Concurrency ),
9195 }
9296
93- // start garbage collector for google ID tokens
97+ // start garbage collector for input-dependant tokens
9498 p .wg .Add (1 )
9599 go func () {
96100 defer p .wg .Done ()
@@ -108,6 +112,14 @@ func NewProvider(ctx context.Context, opts ProviderOptions) *Provider {
108112 }
109113 }
110114 p .googleIDTokensMutex .Unlock ()
115+
116+ p .googleScopedAccessTokensMutex .Lock ()
117+ for ref , token := range p .googleScopedAccessTokens {
118+ if token .isExpired () {
119+ delete (p .googleScopedAccessTokens , ref )
120+ }
121+ }
122+ p .googleScopedAccessTokensMutex .Unlock ()
111123 }
112124 }
113125 }()
@@ -131,14 +143,65 @@ func (p *Provider) GetServiceAccountToken(ctx context.Context, ref *serviceaccou
131143}
132144
133145func (p * Provider ) GetGoogleAccessTokens (ctx context.Context , saToken string ,
134- googleEmail * string ) (* serviceaccounttokens.AccessTokens , time.Time , error ) {
135- ref := serviceaccounts .ReferenceFromToken (saToken )
136- tokens , err := p .getTokens (ctx , ref )
146+ googleEmail * string , scopes []string ) (* serviceaccounttokens.AccessTokens , time.Time , error ) {
147+
148+ saRef := serviceaccounts .ReferenceFromToken (saToken )
149+
150+ // easy case: no scopes
151+ if len (scopes ) == 0 {
152+ tokens , err := p .getTokens (ctx , saRef )
153+ if err != nil {
154+ return nil , time.Time {}, err
155+ }
156+ token := tokens .googleAccessTokens
157+ return token .token , token .expiration (), nil
158+ }
159+
160+ // handle case with custom scopes
161+
162+ var email string
163+ if googleEmail != nil {
164+ email = * googleEmail
165+ }
166+ ref := googleScopedAccessTokenReference {* saRef , email , strings .Join (scopes , "," )}
167+
168+ // check cache first
169+ p .googleScopedAccessTokensMutex .RLock ()
170+ token , ok := p .googleScopedAccessTokens [ref ]
171+ p .googleScopedAccessTokensMutex .RUnlock ()
172+ if ok && ! token .isExpired () {
173+ return & serviceaccounttokens.AccessTokens {DirectAccess : token .token }, token .expiration (), nil
174+ }
175+
176+ // cache miss or token expired. need to cache a new token, so acquire semaphore to limit concurrency
177+ select {
178+ case p .semaphore <- struct {}{}:
179+ case <- ctx .Done ():
180+ return nil , time.Time {}, fmt .Errorf ("request context done while acquiring semaphore: %w" , ctx .Err ())
181+ case <- p .ctx .Done ():
182+ return nil , time.Time {}, fmt .Errorf ("process terminated while acquiring semaphore: %w" , p .ctx .Err ())
183+ }
184+
185+ tokens , expiration , err := p .opts .Source .GetGoogleAccessTokens (ctx , saToken , googleEmail , scopes )
186+
187+ // release concurrency semaphore
188+ <- p .semaphore
189+
190+ // check error
137191 if err != nil {
138192 return nil , time.Time {}, err
139193 }
140- token := tokens .googleAccessTokens
141- return token .token , token .expiration (), nil
194+
195+ // token issued successfully. cache it and return
196+ tokenString := tokens .DirectAccess
197+ if tokenString == "" {
198+ tokenString = tokens .Impersonated
199+ }
200+ token = newToken (tokenString , expiration )
201+ p .googleScopedAccessTokensMutex .Lock ()
202+ p .googleScopedAccessTokens [ref ] = token
203+ p .googleScopedAccessTokensMutex .Unlock ()
204+ return & serviceaccounttokens.AccessTokens {DirectAccess : token .token }, token .expiration (), nil
142205}
143206
144207func (p * Provider ) GetGoogleIdentityToken (ctx context.Context , saRef * serviceaccounts.Reference ,
0 commit comments