Skip to content

Commit

Permalink
Lower GCP token min_ttl to 4 minutes and add backoff to token refresh…
Browse files Browse the repository at this point in the history
… logic (#6638)
  • Loading branch information
mwylde authored Oct 30, 2024
1 parent 7bcc1ad commit 933d348
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 17 deletions.
2 changes: 1 addition & 1 deletion object_store/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ mod cloud {
}

/// Override the minimum remaining TTL for a cached token to be used
#[cfg(feature = "aws")]
#[cfg(any(feature = "aws", feature = "gcp"))]
pub(crate) fn with_min_ttl(mut self, min_ttl: Duration) -> Self {
self.cache = self.cache.with_min_ttl(min_ttl);
self
Expand Down
89 changes: 82 additions & 7 deletions object_store/src/client/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,26 @@ pub(crate) struct TemporaryToken<T> {
/// [`TemporaryToken`] based on its expiry
#[derive(Debug)]
pub(crate) struct TokenCache<T> {
cache: Mutex<Option<TemporaryToken<T>>>,
cache: Mutex<Option<(TemporaryToken<T>, Instant)>>,
min_ttl: Duration,
fetch_backoff: Duration,
}

impl<T> Default for TokenCache<T> {
fn default() -> Self {
Self {
cache: Default::default(),
min_ttl: Duration::from_secs(300),
// How long to wait before re-attempting a token fetch after receiving one that
// is still within the min-ttl
fetch_backoff: Duration::from_millis(100),
}
}
}

impl<T: Clone + Send> TokenCache<T> {
/// Override the minimum remaining TTL for a cached token to be used
#[cfg(feature = "aws")]
#[cfg(any(feature = "aws", feature = "gcp"))]
pub(crate) fn with_min_ttl(self, min_ttl: Duration) -> Self {
Self { min_ttl, ..self }
}
Expand All @@ -61,20 +65,91 @@ impl<T: Clone + Send> TokenCache<T> {
let now = Instant::now();
let mut locked = self.cache.lock().await;

if let Some(cached) = locked.as_ref() {
if let Some((cached, fetched_at)) = locked.as_ref() {
match cached.expiry {
Some(ttl) if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl => {
return Ok(cached.token.clone());
Some(ttl) => {
if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl ||
// if we've recently attempted to fetch this token and it's not actually
// expired, we'll wait to re-fetch it and return the cached one
(fetched_at.elapsed() < self.fetch_backoff && ttl.checked_duration_since(now).is_some())
{
return Ok(cached.token.clone());
}
}
None => return Ok(cached.token.clone()),
_ => (),
}
}

let cached = f().await?;
let token = cached.token.clone();
*locked = Some(cached);
*locked = Some((cached, Instant::now()));

Ok(token)
}
}

#[cfg(test)]
mod test {
use crate::client::token::{TemporaryToken, TokenCache};
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};

// Helper function to create a token with a specific expiry duration from now
fn create_token(expiry_duration: Option<Duration>) -> TemporaryToken<String> {
TemporaryToken {
token: "test_token".to_string(),
expiry: expiry_duration.map(|d| Instant::now() + d),
}
}

#[tokio::test]
async fn test_expired_token_is_refreshed() {
let cache = TokenCache::default();
static COUNTER: AtomicU32 = AtomicU32::new(0);

async fn get_token() -> Result<TemporaryToken<String>, String> {
COUNTER.fetch_add(1, Ordering::SeqCst);
Ok::<_, String>(create_token(Some(Duration::from_secs(0))))
}

// Should fetch initial token
let _ = cache.get_or_insert_with(get_token).await.unwrap();
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

tokio::time::sleep(Duration::from_millis(2)).await;

// Token is expired, so should fetch again
let _ = cache.get_or_insert_with(get_token).await.unwrap();
assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
}

#[tokio::test]
async fn test_min_ttl_causes_refresh() {
let cache = TokenCache {
cache: Default::default(),
min_ttl: Duration::from_secs(1),
fetch_backoff: Duration::from_millis(1),
};

static COUNTER: AtomicU32 = AtomicU32::new(0);

async fn get_token() -> Result<TemporaryToken<String>, String> {
COUNTER.fetch_add(1, Ordering::SeqCst);
Ok::<_, String>(create_token(Some(Duration::from_millis(100))))
}

// Initial fetch
let _ = cache.get_or_insert_with(get_token).await.unwrap();
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

// Should not fetch again since not expired and within fetch_backoff
let _ = cache.get_or_insert_with(get_token).await.unwrap();
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

tokio::time::sleep(Duration::from_millis(2)).await;

// Should fetch, since we've passed fetch_backoff
let _ = cache.get_or_insert_with(get_token).await.unwrap();
assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
}
}
25 changes: 16 additions & 9 deletions object_store/src/gcp/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ use serde::{Deserialize, Serialize};
use snafu::{OptionExt, ResultExt, Snafu};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use url::Url;

use super::credential::{AuthorizedUserSigningCredentials, InstanceSigningCredentialProvider};

const TOKEN_MIN_TTL: Duration = Duration::from_secs(4 * 60);

#[derive(Debug, Snafu)]
enum Error {
#[snafu(display("Missing bucket name"))]
Expand Down Expand Up @@ -463,13 +466,14 @@ impl GoogleCloudStorageBuilder {
)) as _
} else if let Some(credentials) = application_default_credentials.clone() {
match credentials {
ApplicationDefaultCredentials::AuthorizedUser(token) => {
Arc::new(TokenCredentialProvider::new(
ApplicationDefaultCredentials::AuthorizedUser(token) => Arc::new(
TokenCredentialProvider::new(
token,
self.client_options.client()?,
self.retry_config.clone(),
)) as _
}
)
.with_min_ttl(TOKEN_MIN_TTL),
) as _,
ApplicationDefaultCredentials::ServiceAccount(token) => {
Arc::new(TokenCredentialProvider::new(
token.token_provider()?,
Expand All @@ -479,11 +483,14 @@ impl GoogleCloudStorageBuilder {
}
}
} else {
Arc::new(TokenCredentialProvider::new(
InstanceCredentialProvider::default(),
self.client_options.metadata_client()?,
self.retry_config.clone(),
)) as _
Arc::new(
TokenCredentialProvider::new(
InstanceCredentialProvider::default(),
self.client_options.metadata_client()?,
self.retry_config.clone(),
)
.with_min_ttl(TOKEN_MIN_TTL),
) as _
};

let signing_credentials = if let Some(signing_credentials) = self.signing_credentials {
Expand Down

0 comments on commit 933d348

Please sign in to comment.