Skip to content

Commit 13c61aa

Browse files
committed
add support for caching refresh token in armadactl
Signed-off-by: Dejan Zele Pejchev <[email protected]>
1 parent bb49417 commit 13c61aa

File tree

5 files changed

+244
-12
lines changed

5 files changed

+244
-12
lines changed

pkg/client/auth/oidc/cache.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
package oidc
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"time"
9+
10+
"github.com/99designs/keyring"
11+
"golang.org/x/oauth2"
12+
13+
log "github.com/armadaproject/armada/internal/common/logging"
14+
)
15+
16+
const (
17+
keyringServiceName = "armada-oidc"
18+
)
19+
20+
type TokenCache struct {
21+
ring keyring.Keyring
22+
providerUrl string
23+
clientId string
24+
}
25+
26+
type CachedToken struct {
27+
RefreshToken string `json:"refresh_token"`
28+
TokenType string `json:"token_type"`
29+
StoredAt time.Time `json:"stored_at"`
30+
RefreshExpiry *time.Time `json:"refresh_expiry,omitempty"`
31+
}
32+
33+
func NewTokenCache(providerUrl, clientId string) (*TokenCache, error) {
34+
ring, err := keyring.Open(keyring.Config{
35+
ServiceName: keyringServiceName,
36+
AllowedBackends: keyring.AvailableBackends(),
37+
KeychainTrustApplication: true,
38+
KeychainSynchronizable: false,
39+
KeychainAccessibleWhenUnlocked: true,
40+
KWalletAppID: "armada",
41+
KWalletFolder: "armada",
42+
LibSecretCollectionName: "armada",
43+
WinCredPrefix: "armada",
44+
})
45+
if err != nil {
46+
log.Debug("No secure keyring backend available, token caching disabled")
47+
return nil, fmt.Errorf("no secure keyring backend available: %w", err)
48+
}
49+
50+
return &TokenCache{
51+
ring: ring,
52+
providerUrl: providerUrl,
53+
clientId: clientId,
54+
}, nil
55+
}
56+
57+
func (tc *TokenCache) getKey() string {
58+
return fmt.Sprintf("%s:%s", tc.providerUrl, tc.clientId)
59+
}
60+
61+
func (tc *TokenCache) GetCachedRefreshToken() (string, error) {
62+
if tc == nil || tc.ring == nil {
63+
return "", errors.New("token cache not initialized")
64+
}
65+
66+
key := tc.getKey()
67+
item, err := tc.ring.Get(key)
68+
if err != nil {
69+
if errors.Is(err, keyring.ErrKeyNotFound) {
70+
return "", nil // No cached token
71+
}
72+
return "", fmt.Errorf("failed to get token from keyring: %w", err)
73+
}
74+
75+
var cached CachedToken
76+
if err := json.Unmarshal(item.Data, &cached); err != nil {
77+
return "", fmt.Errorf("failed to unmarshal cached token: %w", err)
78+
}
79+
80+
if cached.RefreshExpiry != nil && time.Now().After(*cached.RefreshExpiry) {
81+
_ = tc.DeleteToken()
82+
return "", nil
83+
}
84+
85+
return cached.RefreshToken, nil
86+
}
87+
88+
func (tc *TokenCache) SaveRefreshToken(refreshToken string) error {
89+
if tc == nil || tc.ring == nil {
90+
return errors.New("token cache not initialized")
91+
}
92+
93+
if refreshToken == "" {
94+
return errors.New("refresh token is empty")
95+
}
96+
97+
key := tc.getKey()
98+
cached := CachedToken{
99+
RefreshToken: refreshToken,
100+
TokenType: "Bearer",
101+
StoredAt: time.Now(),
102+
}
103+
104+
data, err := json.Marshal(cached)
105+
if err != nil {
106+
return fmt.Errorf("failed to marshal token: %w", err)
107+
}
108+
109+
item := keyring.Item{
110+
Key: key,
111+
Data: data,
112+
Label: fmt.Sprintf("Armada OIDC Refresh Token (%s)", tc.clientId),
113+
Description: fmt.Sprintf("OAuth2 refresh token for %s", tc.providerUrl),
114+
}
115+
116+
if err := tc.ring.Set(item); err != nil {
117+
return fmt.Errorf("failed to save refresh token to keyring: %w", err)
118+
}
119+
120+
return nil
121+
}
122+
123+
func (tc *TokenCache) DeleteToken() error {
124+
if tc == nil || tc.ring == nil {
125+
return errors.New("token cache not initialized")
126+
}
127+
128+
if err := tc.ring.Remove(tc.getKey()); err != nil && !errors.Is(err, keyring.ErrKeyNotFound) {
129+
return fmt.Errorf("failed to delete token from keyring: %w", err)
130+
}
131+
return nil
132+
}
133+
134+
func RefreshTokenSecurely(ctx context.Context, config *oauth2.Config, refreshToken string, cache *TokenCache) (*oauth2.Token, error) {
135+
if refreshToken == "" {
136+
return nil, errors.New("no refresh token available")
137+
}
138+
139+
oldToken := &oauth2.Token{
140+
RefreshToken: refreshToken,
141+
TokenType: "Bearer",
142+
Expiry: time.Now().Add(-1 * time.Hour),
143+
}
144+
145+
tokenSource := config.TokenSource(ctx, oldToken)
146+
147+
newToken, err := tokenSource.Token()
148+
if err != nil {
149+
return nil, fmt.Errorf("failed to refresh token: %w", err)
150+
}
151+
152+
if newToken.RefreshToken != "" {
153+
if cache != nil {
154+
if saveErr := cache.SaveRefreshToken(newToken.RefreshToken); saveErr != nil {
155+
log.WithError(saveErr).Error("Failed to save refreshed token to cache")
156+
}
157+
}
158+
} else {
159+
newToken.RefreshToken = refreshToken
160+
}
161+
162+
return newToken, nil
163+
}
164+
165+
func TryGetCachedToken(
166+
ctx context.Context,
167+
config *oauth2.Config,
168+
providerUrl string,
169+
clientId string,
170+
cacheEnabled bool,
171+
) (*oauth2.Token, *TokenCache, error) {
172+
if !cacheEnabled {
173+
return nil, nil, nil
174+
}
175+
176+
cache, err := NewTokenCache(providerUrl, clientId)
177+
if err != nil {
178+
log.Debug("Token cache unavailable, proceeding without caching")
179+
return nil, nil, nil
180+
}
181+
182+
cachedRefreshToken, err := cache.GetCachedRefreshToken()
183+
if err != nil || cachedRefreshToken == "" {
184+
return nil, cache, nil
185+
}
186+
187+
newToken, err := RefreshTokenSecurely(ctx, config, cachedRefreshToken, cache)
188+
if err != nil {
189+
_ = cache.DeleteToken()
190+
return nil, cache, nil
191+
}
192+
193+
return newToken, cache, nil
194+
}
195+
196+
func SaveTokenToCache(token *oauth2.Token, cache *TokenCache) {
197+
if cache == nil || token == nil || token.RefreshToken == "" {
198+
return
199+
}
200+
201+
if err := cache.SaveRefreshToken(token.RefreshToken); err != nil {
202+
log.WithError(err).Debug("Failed to save token to cache")
203+
}
204+
}

pkg/client/auth/oidc/device.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type DeviceDetails struct {
2222
Scopes []string
2323
}
2424

25-
func AuthenticateDevice(config DeviceDetails) (*TokenCredentials, error) {
25+
func AuthenticateDevice(config DeviceDetails, cacheToken bool) (*TokenCredentials, error) {
2626
ctx := context.Background()
2727

2828
httpClient := http.DefaultClient
@@ -63,6 +63,12 @@ func AuthenticateDevice(config DeviceDetails) (*TokenCredentials, error) {
6363
Scopes: scopes,
6464
}
6565

66+
// Try to use cached refresh token if enabled
67+
token, cache, err := TryGetCachedToken(ctx, &oauth, config.ProviderUrl, config.ClientId, cacheToken)
68+
if err == nil && token != nil {
69+
return &TokenCredentials{oauth.TokenSource(ctx, token)}, nil
70+
}
71+
6672
deviceFlowResponse, err := requestDeviceAuthorization(ctx, httpClient, claims.DeviceAuthorizationEndpoint, config.ClientId, scopes)
6773
if err != nil {
6874
return nil, err
@@ -97,6 +103,7 @@ func AuthenticateDevice(config DeviceDetails) (*TokenCredentials, error) {
97103
token, err := requestToken(ctx, httpClient, oauth.Endpoint.TokenURL, config.ClientId, deviceFlowResponse.DeviceCode)
98104
if err == nil {
99105
fmt.Printf("\nAuthentication successful!\n\n")
106+
SaveTokenToCache(token, cache)
100107
return &TokenCredentials{oauth.TokenSource(ctx, token)}, nil
101108
}
102109

pkg/client/auth/oidc/password.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ type ClientPasswordDetails struct {
1515
Password string
1616
}
1717

18-
func AuthenticateWithPassword(config ClientPasswordDetails) (*TokenCredentials, error) {
18+
func AuthenticateWithPassword(config ClientPasswordDetails, cacheToken bool) (*TokenCredentials, error) {
1919
ctx := context.Background()
2020

2121
provider, err := openId.NewProvider(ctx, config.ProviderUrl)
@@ -29,6 +29,16 @@ func AuthenticateWithPassword(config ClientPasswordDetails) (*TokenCredentials,
2929
Endpoint: provider.Endpoint(),
3030
}
3131

32+
// Try to use cached refresh token if enabled
33+
token, cache, err := TryGetCachedToken(ctx, authConfig, config.ProviderUrl, config.ClientId, cacheToken)
34+
if err == nil && token != nil {
35+
return &TokenCredentials{oauth2.ReuseTokenSource(token, &FunctionTokenSource{
36+
func() (*oauth2.Token, error) {
37+
return authConfig.PasswordCredentialsToken(ctx, config.Username, config.Password)
38+
},
39+
})}, nil
40+
}
41+
3242
source := &FunctionTokenSource{
3343
func() (*oauth2.Token, error) {
3444
return authConfig.PasswordCredentialsToken(ctx, config.Username, config.Password)
@@ -38,6 +48,8 @@ func AuthenticateWithPassword(config ClientPasswordDetails) (*TokenCredentials,
3848
if err != nil {
3949
return nil, err
4050
}
51+
52+
SaveTokenToCache(t, cache)
4153
cachedSource := oauth2.ReuseTokenSource(t, source)
4254
return &TokenCredentials{cachedSource}, nil
4355
}

pkg/client/auth/oidc/pkce.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,33 @@ type PKCEDetails struct {
2525
Scopes []string
2626
}
2727

28-
func AuthenticatePkce(config PKCEDetails) (*TokenCredentials, error) {
28+
func AuthenticatePkce(config PKCEDetails, cacheToken bool) (*TokenCredentials, error) {
2929
ctx := context.Background()
3030

31-
result := make(chan *oauth2.Token)
32-
errorResult := make(chan error)
33-
3431
provider, err := openId.NewProvider(ctx, config.ProviderUrl)
3532
if err != nil {
3633
return nil, err
3734
}
3835

39-
localUrl := "localhost:" + strconv.Itoa(int(config.LocalPort))
40-
4136
oauth := oauth2.Config{
4237
ClientID: config.ClientId,
4338
Endpoint: provider.Endpoint(),
44-
RedirectURL: "http://" + localUrl + "/auth/callback",
39+
RedirectURL: "http://localhost:" + strconv.Itoa(int(config.LocalPort)) + "/auth/callback",
4540
Scopes: append(config.Scopes, openId.ScopeOpenID),
4641
}
4742

43+
// Try to use cached refresh token if enabled
44+
token, cache, err := TryGetCachedToken(ctx, &oauth, config.ProviderUrl, config.ClientId, cacheToken)
45+
if err == nil && token != nil {
46+
return &TokenCredentials{oauth.TokenSource(ctx, token)}, nil
47+
}
48+
49+
// Perform interactive authentication if no valid cached token
50+
result := make(chan *oauth2.Token)
51+
errorResult := make(chan error)
52+
53+
localUrl := "localhost:" + strconv.Itoa(int(config.LocalPort))
54+
4855
state := randomStringBase64() // xss protection
4956
challenge := randomStringBase64()
5057
challengeSum := sha256.Sum256([]byte(challenge))
@@ -111,6 +118,7 @@ func AuthenticatePkce(config PKCEDetails) (*TokenCredentials, error) {
111118

112119
select {
113120
case t := <-result:
121+
SaveTokenToCache(t, cache)
114122
return &TokenCredentials{oauth.TokenSource(ctx, t)}, nil
115123
case e := <-errorResult:
116124
return nil, e

pkg/client/connection.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type ApiConnectionDetails struct {
4545
OpenIdKubernetesAuth oidc.KubernetesDetails
4646
ForceNoTls bool
4747
ExecAuth exec.CommandDetails
48+
CacheRefreshToken bool
4849
}
4950

5051
type ConnectionDetails func() (*ApiConnectionDetails, error)
@@ -101,11 +102,11 @@ func perRpcCredentials(config *ApiConnectionDetails) (credentials.PerRPCCredenti
101102
} else if config.KubernetesNativeAuth.Enabled {
102103
return kubernetes.AuthenticateKubernetesNative(config.KubernetesNativeAuth)
103104
} else if config.OpenIdAuth.ProviderUrl != "" {
104-
return oidc.AuthenticatePkce(config.OpenIdAuth)
105+
return oidc.AuthenticatePkce(config.OpenIdAuth, config.CacheRefreshToken)
105106
} else if config.OpenIdDeviceAuth.ProviderUrl != "" {
106-
return oidc.AuthenticateDevice(config.OpenIdDeviceAuth)
107+
return oidc.AuthenticateDevice(config.OpenIdDeviceAuth, config.CacheRefreshToken)
107108
} else if config.OpenIdPasswordAuth.ProviderUrl != "" {
108-
return oidc.AuthenticateWithPassword(config.OpenIdPasswordAuth)
109+
return oidc.AuthenticateWithPassword(config.OpenIdPasswordAuth, config.CacheRefreshToken)
109110
} else if config.OpenIdClientCredentialsAuth.ProviderUrl != "" {
110111
return oidc.AuthenticateWithClientCredentials(config.OpenIdClientCredentialsAuth)
111112
} else if config.OpenIdKubernetesAuth.ProviderUrl != "" {

0 commit comments

Comments
 (0)