diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index b65fea7436b2..76d1c1f22f58 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -774,7 +774,7 @@ mod cloud { } /// Override the minimum remaining TTL for a cached token to be used - #[cfg(feature = "aws")] + #[cfg(any(feature = "aws", feature = "gcp"))] pub(crate) fn with_min_ttl(mut self, min_ttl: Duration) -> Self { self.cache = self.cache.with_min_ttl(min_ttl); self diff --git a/object_store/src/client/token.rs b/object_store/src/client/token.rs index f7294190f56d..81ffc110ac0f 100644 --- a/object_store/src/client/token.rs +++ b/object_store/src/client/token.rs @@ -33,8 +33,9 @@ pub(crate) struct TemporaryToken { /// [`TemporaryToken`] based on its expiry #[derive(Debug)] pub(crate) struct TokenCache { - cache: Mutex>>, + cache: Mutex, Instant)>>, min_ttl: Duration, + fetch_backoff: Duration, } impl Default for TokenCache { @@ -42,13 +43,16 @@ impl Default for TokenCache { Self { cache: Default::default(), min_ttl: Duration::from_secs(300), + // How long to wait before re-attempting a token fetch after receiving one that + // is still within the min-ttl + fetch_backoff: Duration::from_millis(100), } } } impl TokenCache { /// Override the minimum remaining TTL for a cached token to be used - #[cfg(feature = "aws")] + #[cfg(any(feature = "aws", feature = "gcp"))] pub(crate) fn with_min_ttl(self, min_ttl: Duration) -> Self { Self { min_ttl, ..self } } @@ -61,20 +65,91 @@ impl TokenCache { let now = Instant::now(); let mut locked = self.cache.lock().await; - if let Some(cached) = locked.as_ref() { + if let Some((cached, fetched_at)) = locked.as_ref() { match cached.expiry { - Some(ttl) if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl => { - return Ok(cached.token.clone()); + Some(ttl) => { + if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl || + // if we've recently attempted to fetch this token and it's not actually + // expired, we'll wait to re-fetch it and return the cached one + (fetched_at.elapsed() < self.fetch_backoff && ttl.checked_duration_since(now).is_some()) + { + return Ok(cached.token.clone()); + } } None => return Ok(cached.token.clone()), - _ => (), } } let cached = f().await?; let token = cached.token.clone(); - *locked = Some(cached); + *locked = Some((cached, Instant::now())); Ok(token) } } + +#[cfg(test)] +mod test { + use crate::client::token::{TemporaryToken, TokenCache}; + use std::sync::atomic::{AtomicU32, Ordering}; + use std::time::{Duration, Instant}; + + // Helper function to create a token with a specific expiry duration from now + fn create_token(expiry_duration: Option) -> TemporaryToken { + TemporaryToken { + token: "test_token".to_string(), + expiry: expiry_duration.map(|d| Instant::now() + d), + } + } + + #[tokio::test] + async fn test_expired_token_is_refreshed() { + let cache = TokenCache::default(); + static COUNTER: AtomicU32 = AtomicU32::new(0); + + async fn get_token() -> Result, String> { + COUNTER.fetch_add(1, Ordering::SeqCst); + Ok::<_, String>(create_token(Some(Duration::from_secs(0)))) + } + + // Should fetch initial token + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + + tokio::time::sleep(Duration::from_millis(2)).await; + + // Token is expired, so should fetch again + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn test_min_ttl_causes_refresh() { + let cache = TokenCache { + cache: Default::default(), + min_ttl: Duration::from_secs(1), + fetch_backoff: Duration::from_millis(1), + }; + + static COUNTER: AtomicU32 = AtomicU32::new(0); + + async fn get_token() -> Result, String> { + COUNTER.fetch_add(1, Ordering::SeqCst); + Ok::<_, String>(create_token(Some(Duration::from_millis(100)))) + } + + // Initial fetch + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + + // Should not fetch again since not expired and within fetch_backoff + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 1); + + tokio::time::sleep(Duration::from_millis(2)).await; + + // Should fetch, since we've passed fetch_backoff + let _ = cache.get_or_insert_with(get_token).await.unwrap(); + assert_eq!(COUNTER.load(Ordering::SeqCst), 2); + } +} diff --git a/object_store/src/gcp/builder.rs b/object_store/src/gcp/builder.rs index 26cc8211d2dc..fac923c4b9a0 100644 --- a/object_store/src/gcp/builder.rs +++ b/object_store/src/gcp/builder.rs @@ -30,10 +30,13 @@ use serde::{Deserialize, Serialize}; use snafu::{OptionExt, ResultExt, Snafu}; use std::str::FromStr; use std::sync::Arc; +use std::time::Duration; use url::Url; use super::credential::{AuthorizedUserSigningCredentials, InstanceSigningCredentialProvider}; +const TOKEN_MIN_TTL: Duration = Duration::from_secs(4 * 60); + #[derive(Debug, Snafu)] enum Error { #[snafu(display("Missing bucket name"))] @@ -463,13 +466,14 @@ impl GoogleCloudStorageBuilder { )) as _ } else if let Some(credentials) = application_default_credentials.clone() { match credentials { - ApplicationDefaultCredentials::AuthorizedUser(token) => { - Arc::new(TokenCredentialProvider::new( + ApplicationDefaultCredentials::AuthorizedUser(token) => Arc::new( + TokenCredentialProvider::new( token, self.client_options.client()?, self.retry_config.clone(), - )) as _ - } + ) + .with_min_ttl(TOKEN_MIN_TTL), + ) as _, ApplicationDefaultCredentials::ServiceAccount(token) => { Arc::new(TokenCredentialProvider::new( token.token_provider()?, @@ -479,11 +483,14 @@ impl GoogleCloudStorageBuilder { } } } else { - Arc::new(TokenCredentialProvider::new( - InstanceCredentialProvider::default(), - self.client_options.metadata_client()?, - self.retry_config.clone(), - )) as _ + Arc::new( + TokenCredentialProvider::new( + InstanceCredentialProvider::default(), + self.client_options.metadata_client()?, + self.retry_config.clone(), + ) + .with_min_ttl(TOKEN_MIN_TTL), + ) as _ }; let signing_credentials = if let Some(signing_credentials) = self.signing_credentials {