Skip to content

Commit cf746b3

Browse files
committed
feat(armadactl): add OIDC refresh token caching
Add token caching for OIDC auth to avoid repeated browser authentication. The refresh token is securely stored in the system keyring (macOS Keychain, Windows Credential Manager, Linux Secret Service). To enable, add `cacheRefreshToken: true` to your context in ~/.armadactl.yaml and include `offline_access` in your scopes. Note: armadactl must be built with CGO_ENABLED=1 on macOS for keychain access. Signed-off-by: Dejan Zele Pejchev <[email protected]>
1 parent 51cc027 commit cf746b3

File tree

7 files changed

+243
-15
lines changed

7 files changed

+243
-15
lines changed

.goreleaser.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ builds:
138138
goarch:
139139
- amd64
140140
- arm64
141+
overrides:
142+
- goos: darwin
143+
env:
144+
- CGO_ENABLED=1
141145

142146
source:
143147
enabled: true

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ require (
5252
)
5353

5454
require (
55+
github.com/99designs/keyring v1.2.1
5556
github.com/IBM/pgxpoolprometheus v1.1.1
5657
github.com/Masterminds/semver/v3 v3.3.1
5758
github.com/benbjohnson/immutable v0.4.3
@@ -89,7 +90,6 @@ require (
8990
require (
9091
dario.cat/mergo v1.0.1 // indirect
9192
github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect
92-
github.com/99designs/keyring v1.2.1 // indirect
9393
github.com/AlekSi/pointer v1.2.0 // indirect
9494
github.com/AthenZ/athenz v1.10.39 // indirect
9595
github.com/DataDog/zstd v1.5.5 // indirect

pkg/client/auth/oidc/cache.go

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
package oidc
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"time"
9+
10+
"github.com/99designs/keyring"
11+
"golang.org/x/oauth2"
12+
13+
log "github.com/armadaproject/armada/internal/common/logging"
14+
)
15+
16+
const (
17+
keyringServiceName = "armada-oidc"
18+
)
19+
20+
var errCacheNotInitialized = errors.New("token cache not initialized")
21+
22+
type tokenCache struct {
23+
ring keyring.Keyring
24+
providerUrl string
25+
clientId string
26+
}
27+
28+
type cachedToken struct {
29+
RefreshToken string `json:"refresh_token"`
30+
}
31+
32+
func newTokenCache(providerUrl, clientId string) (*tokenCache, error) {
33+
ring, err := keyring.Open(keyring.Config{
34+
ServiceName: keyringServiceName,
35+
AllowedBackends: keyring.AvailableBackends(),
36+
KeychainTrustApplication: true,
37+
KeychainSynchronizable: false,
38+
KeychainAccessibleWhenUnlocked: true,
39+
KWalletAppID: "armada",
40+
KWalletFolder: "armada",
41+
LibSecretCollectionName: "armada",
42+
WinCredPrefix: "armada",
43+
})
44+
if err != nil {
45+
return nil, fmt.Errorf("no secure keyring backend available: %w", err)
46+
}
47+
48+
return &tokenCache{
49+
ring: ring,
50+
providerUrl: providerUrl,
51+
clientId: clientId,
52+
}, nil
53+
}
54+
55+
func (tc *tokenCache) getKey() string {
56+
return fmt.Sprintf("%s:%s", tc.providerUrl, tc.clientId)
57+
}
58+
59+
func (tc *tokenCache) getCachedRefreshToken() (string, error) {
60+
if tc == nil || tc.ring == nil {
61+
return "", errCacheNotInitialized
62+
}
63+
64+
key := tc.getKey()
65+
item, err := tc.ring.Get(key)
66+
if err != nil {
67+
if errors.Is(err, keyring.ErrKeyNotFound) {
68+
return "", nil
69+
}
70+
return "", fmt.Errorf("failed to get token from keyring: %w", err)
71+
}
72+
73+
var cached cachedToken
74+
if err := json.Unmarshal(item.Data, &cached); err != nil {
75+
return "", fmt.Errorf("failed to unmarshal cached token: %w", err)
76+
}
77+
78+
return cached.RefreshToken, nil
79+
}
80+
81+
func (tc *tokenCache) saveRefreshToken(refreshToken string) error {
82+
if tc == nil || tc.ring == nil {
83+
return errCacheNotInitialized
84+
}
85+
86+
if refreshToken == "" {
87+
return errors.New("refresh token is empty")
88+
}
89+
90+
key := tc.getKey()
91+
cached := cachedToken{
92+
RefreshToken: refreshToken,
93+
}
94+
95+
data, err := json.Marshal(cached)
96+
if err != nil {
97+
return fmt.Errorf("failed to marshal token: %w", err)
98+
}
99+
100+
item := keyring.Item{
101+
Key: key,
102+
Data: data,
103+
Label: fmt.Sprintf("Armada OIDC Refresh Token (%s)", tc.clientId),
104+
Description: fmt.Sprintf("OAuth2 refresh token for %s", tc.providerUrl),
105+
}
106+
107+
if err := tc.ring.Set(item); err != nil {
108+
return fmt.Errorf("failed to save refresh token to keyring: %w", err)
109+
}
110+
111+
return nil
112+
}
113+
114+
func (tc *tokenCache) deleteToken() error {
115+
if tc == nil || tc.ring == nil {
116+
return errCacheNotInitialized
117+
}
118+
119+
if err := tc.ring.Remove(tc.getKey()); err != nil && !errors.Is(err, keyring.ErrKeyNotFound) {
120+
return fmt.Errorf("failed to delete token from keyring: %w", err)
121+
}
122+
return nil
123+
}
124+
125+
// refreshToken exchanges a refresh token for a new access token.
126+
// If the provider returns a new refresh token (rotation), it updates the cache.
127+
func refreshToken(ctx context.Context, config *oauth2.Config, refreshToken string, cache *tokenCache) (*oauth2.Token, error) {
128+
if refreshToken == "" {
129+
return nil, errors.New("no refresh token available")
130+
}
131+
132+
oldToken := &oauth2.Token{
133+
RefreshToken: refreshToken,
134+
TokenType: "Bearer",
135+
Expiry: time.Now().Add(-1 * time.Hour),
136+
}
137+
138+
tokenSource := config.TokenSource(ctx, oldToken)
139+
140+
newToken, err := tokenSource.Token()
141+
if err != nil {
142+
return nil, fmt.Errorf("failed to refresh token: %w", err)
143+
}
144+
145+
if newToken.RefreshToken != "" {
146+
if cache != nil {
147+
if saveErr := cache.saveRefreshToken(newToken.RefreshToken); saveErr != nil {
148+
log.WithError(saveErr).Error("Failed to save refreshed token to cache")
149+
}
150+
}
151+
} else {
152+
newToken.RefreshToken = refreshToken
153+
}
154+
155+
return newToken, nil
156+
}
157+
158+
// tryGetCachedToken attempts to retrieve and refresh a cached token.
159+
// Returns (nil, cache) if no valid cached token exists but caching is available.
160+
func tryGetCachedToken(
161+
ctx context.Context,
162+
config *oauth2.Config,
163+
providerUrl string,
164+
clientId string,
165+
cacheEnabled bool,
166+
) (*oauth2.Token, *tokenCache) {
167+
if !cacheEnabled {
168+
return nil, nil
169+
}
170+
171+
cache, err := newTokenCache(providerUrl, clientId)
172+
if err != nil {
173+
log.Warn("Token cache unavailable, proceeding without caching")
174+
return nil, nil
175+
}
176+
177+
cachedRefreshToken, err := cache.getCachedRefreshToken()
178+
if err != nil || cachedRefreshToken == "" {
179+
return nil, cache
180+
}
181+
182+
newToken, err := refreshToken(ctx, config, cachedRefreshToken, cache)
183+
if err != nil {
184+
_ = cache.deleteToken()
185+
return nil, cache
186+
}
187+
188+
return newToken, cache
189+
}
190+
191+
func saveTokenToCache(token *oauth2.Token, cache *tokenCache) {
192+
if cache == nil || token == nil || token.RefreshToken == "" {
193+
return
194+
}
195+
196+
if err := cache.saveRefreshToken(token.RefreshToken); err != nil {
197+
log.WithError(err).Error("Failed to save token to cache")
198+
}
199+
}

pkg/client/auth/oidc/device.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type DeviceDetails struct {
2222
Scopes []string
2323
}
2424

25-
func AuthenticateDevice(config DeviceDetails) (*TokenCredentials, error) {
25+
func AuthenticateDevice(config DeviceDetails, cacheToken bool) (*TokenCredentials, error) {
2626
ctx := context.Background()
2727

2828
httpClient := http.DefaultClient
@@ -63,6 +63,12 @@ func AuthenticateDevice(config DeviceDetails) (*TokenCredentials, error) {
6363
Scopes: scopes,
6464
}
6565

66+
// Try to use cached refresh token if enabled
67+
token, cache := tryGetCachedToken(ctx, &oauth, config.ProviderUrl, config.ClientId, cacheToken)
68+
if token != nil {
69+
return &TokenCredentials{oauth.TokenSource(ctx, token)}, nil
70+
}
71+
6672
deviceFlowResponse, err := requestDeviceAuthorization(ctx, httpClient, claims.DeviceAuthorizationEndpoint, config.ClientId, scopes)
6773
if err != nil {
6874
return nil, err
@@ -97,6 +103,7 @@ func AuthenticateDevice(config DeviceDetails) (*TokenCredentials, error) {
97103
token, err := requestToken(ctx, httpClient, oauth.Endpoint.TokenURL, config.ClientId, deviceFlowResponse.DeviceCode)
98104
if err == nil {
99105
fmt.Printf("\nAuthentication successful!\n\n")
106+
saveTokenToCache(token, cache)
100107
return &TokenCredentials{oauth.TokenSource(ctx, token)}, nil
101108
}
102109

pkg/client/auth/oidc/password.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ type ClientPasswordDetails struct {
1515
Password string
1616
}
1717

18-
func AuthenticateWithPassword(config ClientPasswordDetails) (*TokenCredentials, error) {
18+
func AuthenticateWithPassword(config ClientPasswordDetails, cacheToken bool) (*TokenCredentials, error) {
1919
ctx := context.Background()
2020

2121
provider, err := openId.NewProvider(ctx, config.ProviderUrl)
@@ -34,10 +34,19 @@ func AuthenticateWithPassword(config ClientPasswordDetails) (*TokenCredentials,
3434
return authConfig.PasswordCredentialsToken(ctx, config.Username, config.Password)
3535
},
3636
}
37+
38+
// Try to use cached refresh token if enabled
39+
token, cache := tryGetCachedToken(ctx, authConfig, config.ProviderUrl, config.ClientId, cacheToken)
40+
if token != nil {
41+
return &TokenCredentials{oauth2.ReuseTokenSource(token, source)}, nil
42+
}
43+
3744
t, err := source.Token()
3845
if err != nil {
3946
return nil, err
4047
}
48+
49+
saveTokenToCache(t, cache)
4150
cachedSource := oauth2.ReuseTokenSource(t, source)
4251
return &TokenCredentials{cachedSource}, nil
4352
}

pkg/client/auth/oidc/pkce.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,16 @@ type PKCEDetails struct {
2525
Scopes []string
2626
}
2727

28-
func AuthenticatePkce(config PKCEDetails) (*TokenCredentials, error) {
28+
func AuthenticatePkce(config PKCEDetails, cacheToken bool) (*TokenCredentials, error) {
2929
ctx := context.Background()
3030

31-
result := make(chan *oauth2.Token)
32-
errorResult := make(chan error)
33-
3431
provider, err := openId.NewProvider(ctx, config.ProviderUrl)
3532
if err != nil {
3633
return nil, err
3734
}
3835

39-
localUrl := "localhost:" + strconv.Itoa(int(config.LocalPort))
36+
portStr := strconv.Itoa(int(config.LocalPort))
37+
localUrl := "localhost:" + portStr
4038

4139
oauth := oauth2.Config{
4240
ClientID: config.ClientId,
@@ -45,6 +43,16 @@ func AuthenticatePkce(config PKCEDetails) (*TokenCredentials, error) {
4543
Scopes: append(config.Scopes, openId.ScopeOpenID),
4644
}
4745

46+
// Try to use cached refresh token if enabled
47+
token, cache := tryGetCachedToken(ctx, &oauth, config.ProviderUrl, config.ClientId, cacheToken)
48+
if token != nil {
49+
return &TokenCredentials{oauth.TokenSource(ctx, token)}, nil
50+
}
51+
52+
// Perform interactive authentication if no valid cached token
53+
result := make(chan *oauth2.Token)
54+
errorResult := make(chan error)
55+
4856
state := randomStringBase64() // xss protection
4957
challenge := randomStringBase64()
5058
challengeSum := sha256.Sum256([]byte(challenge))
@@ -104,18 +112,18 @@ func AuthenticatePkce(config PKCEDetails) (*TokenCredentials, error) {
104112
}()
105113

106114
cmd, err := openBrowser("http://" + localUrl)
115+
if err != nil {
116+
return nil, err
117+
}
107118
defer func() {
108119
if err := cmd.Process.Kill(); err != nil {
109120
log.WithStacktrace(err).Error("unable to kill process")
110121
}
111122
}()
112123

113-
if err != nil {
114-
return nil, err
115-
}
116-
117124
select {
118125
case t := <-result:
126+
saveTokenToCache(t, cache)
119127
return &TokenCredentials{oauth.TokenSource(ctx, t)}, nil
120128
case e := <-errorResult:
121129
return nil, e

pkg/client/connection.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type ApiConnectionDetails struct {
4545
OpenIdKubernetesAuth oidc.KubernetesDetails
4646
ForceNoTls bool
4747
ExecAuth exec.CommandDetails
48+
CacheRefreshToken bool
4849
}
4950

5051
type ConnectionDetails func() (*ApiConnectionDetails, error)
@@ -101,11 +102,11 @@ func perRpcCredentials(config *ApiConnectionDetails) (credentials.PerRPCCredenti
101102
} else if config.KubernetesNativeAuth.Enabled {
102103
return kubernetes.AuthenticateKubernetesNative(config.KubernetesNativeAuth)
103104
} else if config.OpenIdAuth.ProviderUrl != "" {
104-
return oidc.AuthenticatePkce(config.OpenIdAuth)
105+
return oidc.AuthenticatePkce(config.OpenIdAuth, config.CacheRefreshToken)
105106
} else if config.OpenIdDeviceAuth.ProviderUrl != "" {
106-
return oidc.AuthenticateDevice(config.OpenIdDeviceAuth)
107+
return oidc.AuthenticateDevice(config.OpenIdDeviceAuth, config.CacheRefreshToken)
107108
} else if config.OpenIdPasswordAuth.ProviderUrl != "" {
108-
return oidc.AuthenticateWithPassword(config.OpenIdPasswordAuth)
109+
return oidc.AuthenticateWithPassword(config.OpenIdPasswordAuth, config.CacheRefreshToken)
109110
} else if config.OpenIdClientCredentialsAuth.ProviderUrl != "" {
110111
return oidc.AuthenticateWithClientCredentials(config.OpenIdClientCredentialsAuth)
111112
} else if config.OpenIdKubernetesAuth.ProviderUrl != "" {

0 commit comments

Comments
 (0)