Cache object stores per scheme + bucket per session (#99)
In this PR, we cache object stores per scheme (e.g. s3) + bucket (or container) combination in each Postgres session. This reduces authentication costs, since authentication is performed only the first time a given store is used in the session.

For s3, object_store does not perform an STS assume_role call to obtain a temporary token, so pg_parquet uses the AWS SDK to perform it and then configures object_store with the temporary token it fetched. pg_parquet also checks the expiration of the AWS tokens and fetches a new temporary token when the current one expires (obviously only if you configured temporary-token auth via config).
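
As an illustration, here is a minimal sketch of what such a per-session cache can look like, assuming a plain map keyed by scheme + bucket. The real implementation lives in src/object_store/object_store_cache.rs (not shown in this excerpt); apart from ObjectStoreWithExpiration, which appears in the diff below, the names are hypothetical.

```rust
use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;

use object_store::ObjectStore;

// Cached store plus the expiry of the credentials it was built with.
pub struct ObjectStoreWithExpiration {
    pub object_store: Arc<dyn ObjectStore>,
    pub expire_at: Option<SystemTime>, // None => credentials never expire
}

// Cache key: scheme (e.g. "s3", "az") + bucket/container name.
type CacheKey = (String, String);

#[derive(Default)]
pub struct ObjectStoreCache {
    entries: HashMap<CacheKey, ObjectStoreWithExpiration>,
}

impl ObjectStoreCache {
    // Reuse the cached store unless its temporary credentials have expired;
    // otherwise authenticate again via `create` and cache the fresh store.
    pub fn get_or_create(
        &mut self,
        key: CacheKey,
        create: impl FnOnce() -> ObjectStoreWithExpiration,
    ) -> Arc<dyn ObjectStore> {
        let expired = self
            .entries
            .get(&key)
            .and_then(|entry| entry.expire_at)
            .is_some_and(|expire_at| expire_at <= SystemTime::now());

        if expired {
            self.entries.remove(&key);
        }

        Arc::clone(&self.entries.entry(key).or_insert_with(create).object_store)
    }
}
```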

Closes #93
aykut-bozkurt authored Jan 30, 2025
1 parent fa728df commit 3ff46d5
Showing 10 changed files with 445 additions and 64 deletions.
1 change: 1 addition & 0 deletions .devcontainer/entrypoint.sh
@@ -4,5 +4,6 @@ trap "echo 'Caught termination signal. Exiting...'; exit 0" SIGINT SIGTERM

# create azurite container
az storage container create -n $AZURE_TEST_CONTAINER_NAME --connection-string $AZURE_STORAGE_CONNECTION_STRING
+az storage container create -n ${AZURE_TEST_CONTAINER_NAME}2 --connection-string $AZURE_STORAGE_CONNECTION_STRING

sleep infinity
3 changes: 2 additions & 1 deletion .devcontainer/minio-entrypoint.sh
@@ -14,7 +14,8 @@ done
# set access key and secret key
mc alias set local $AWS_ENDPOINT_URL $MINIO_ROOT_USER $MINIO_ROOT_PASSWORD

-# create bucket
+# create buckets
mc mb local/$AWS_S3_TEST_BUCKET
+mc mb local/${AWS_S3_TEST_BUCKET}2

wait $minio_pid
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -138,6 +138,7 @@ jobs:
# create container
az storage container create -n $AZURE_TEST_CONTAINER_NAME --connection-string $AZURE_STORAGE_CONNECTION_STRING
+az storage container create -n ${AZURE_TEST_CONTAINER_NAME}2 --connection-string $AZURE_STORAGE_CONNECTION_STRING
- name: Run tests
run: |
10 changes: 5 additions & 5 deletions src/arrow_parquet/uri_utils.rs
@@ -18,8 +18,8 @@ use pgrx::{
use url::Url;

use crate::{
-    arrow_parquet::parquet_writer::DEFAULT_ROW_GROUP_SIZE, object_store::create_object_store,
-    PG_BACKEND_TOKIO_RUNTIME,
+    arrow_parquet::parquet_writer::DEFAULT_ROW_GROUP_SIZE,
+    object_store::object_store_cache::get_or_create_object_store, PG_BACKEND_TOKIO_RUNTIME,
};

const PARQUET_OBJECT_STORE_READ_ROLE: &str = "parquet_object_store_read";
@@ -58,7 +58,7 @@ pub(crate) fn parquet_schema_from_uri(uri: &Url) -> SchemaDescriptor {

pub(crate) fn parquet_metadata_from_uri(uri: &Url) -> Arc<ParquetMetaData> {
let copy_from = true;
-    let (parquet_object_store, location) = create_object_store(uri, copy_from);
+    let (parquet_object_store, location) = get_or_create_object_store(uri, copy_from);

PG_BACKEND_TOKIO_RUNTIME.block_on(async {
let object_store_meta = parquet_object_store
@@ -81,7 +81,7 @@ pub(crate) fn parquet_metadata_from_uri(uri: &Url) -> Arc<ParquetMetaData> {

pub(crate) fn parquet_reader_from_uri(uri: &Url) -> ParquetRecordBatchStream<ParquetObjectReader> {
let copy_from = true;
-    let (parquet_object_store, location) = create_object_store(uri, copy_from);
+    let (parquet_object_store, location) = get_or_create_object_store(uri, copy_from);

PG_BACKEND_TOKIO_RUNTIME.block_on(async {
let object_store_meta = parquet_object_store
@@ -113,7 +113,7 @@ pub(crate) fn parquet_writer_from_uri(
writer_props: WriterProperties,
) -> AsyncArrowWriter<ParquetObjectWriter> {
let copy_from = false;
-    let (parquet_object_store, location) = create_object_store(uri, copy_from);
+    let (parquet_object_store, location) = get_or_create_object_store(uri, copy_from);

let parquet_object_writer = ParquetObjectWriter::new(parquet_object_store, location);

45 changes: 1 addition & 44 deletions src/object_store.rs
@@ -1,8 +1,3 @@
-use std::sync::Arc;
-
-use object_store::{path::Path, ObjectStore, ObjectStoreScheme};
-use url::Url;
-
use crate::{
arrow_parquet::uri_utils::uri_as_string,
object_store::{
@@ -15,42 +10,4 @@ use crate::{
pub(crate) mod aws;
pub(crate) mod azure;
pub(crate) mod local_file;

-pub(crate) fn create_object_store(uri: &Url, copy_from: bool) -> (Arc<dyn ObjectStore>, Path) {
-    let (scheme, path) = ObjectStoreScheme::parse(uri).unwrap_or_else(|_| {
-        panic!(
-            "unrecognized uri {}. pg_parquet supports local paths, s3:// or azure:// schemes.",
-            uri
-        )
-    });
-
-    // object_store crate can recognize a bunch of different schemes and paths, but we only support
-    // local, azure, and s3 schemes with a subset of all supported paths.
-    match scheme {
-        ObjectStoreScheme::AmazonS3 => {
-            let storage_container = Arc::new(create_s3_object_store(uri));
-
-            (storage_container, path)
-        }
-        ObjectStoreScheme::MicrosoftAzure => {
-            let storage_container = Arc::new(create_azure_object_store(uri));
-
-            (storage_container, path)
-        }
-        ObjectStoreScheme::Local => {
-            let storage_container = Arc::new(create_local_file_object_store(uri, copy_from));
-
-            let path =
-                Path::from_filesystem_path(uri_as_string(uri)).unwrap_or_else(|e| panic!("{}", e));
-
-            (storage_container, path)
-        }
-        _ => {
-            panic!(
-                "unsupported scheme {} in uri {}. pg_parquet supports local paths, s3:// or azure:// schemes.",
-                uri.scheme(),
-                uri
-            );
-        }
-    }
-}
+pub(crate) mod object_store_cache;
23 changes: 18 additions & 5 deletions src/object_store/aws.rs
@@ -1,9 +1,11 @@
+use std::{sync::Arc, time::SystemTime};
+
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
-use object_store::aws::{AmazonS3, AmazonS3Builder};
+use object_store::aws::AmazonS3Builder;
use url::Url;

-use super::PG_BACKEND_TOKIO_RUNTIME;
+use super::{object_store_cache::ObjectStoreWithExpiration, PG_BACKEND_TOKIO_RUNTIME};

// create_s3_object_store creates an AmazonS3 object store with the given bucket name.
// It is configured by environment variables and aws config files as fallback method.
@@ -19,7 +21,7 @@ use super::PG_BACKEND_TOKIO_RUNTIME;
// - AWS_CONFIG_FILE (env var only)
// - AWS_PROFILE (env var only)
// - AWS_ALLOW_HTTP (env var only, object_store specific)
-pub(crate) fn create_s3_object_store(uri: &Url) -> AmazonS3 {
+pub(crate) fn create_s3_object_store(uri: &Url) -> ObjectStoreWithExpiration {
let bucket_name = parse_s3_bucket(uri).unwrap_or_else(|| {
panic!("unsupported s3 uri: {}", uri);
});
@@ -58,10 +60,17 @@ pub(crate) fn create_s3_object_store(uri: &Url) -> AmazonS3 {
aws_s3_builder = aws_s3_builder.with_region(region);
}

-    aws_s3_builder.build().unwrap_or_else(|e| panic!("{}", e))
+    let object_store = aws_s3_builder.build().unwrap_or_else(|e| panic!("{}", e));
+
+    let expire_at = aws_s3_config.expire_at;
+
+    ObjectStoreWithExpiration {
+        object_store: Arc::new(object_store),
+        expire_at,
+    }
}

-fn parse_s3_bucket(uri: &Url) -> Option<String> {
+pub(crate) fn parse_s3_bucket(uri: &Url) -> Option<String> {
let host = uri.host_str()?;

// s3(a)://{bucket}/key
@@ -98,6 +107,7 @@ struct AwsS3Config {
access_key_id: Option<String>,
secret_access_key: Option<String>,
session_token: Option<String>,
+    expire_at: Option<SystemTime>,
endpoint_url: Option<String>,
allow_http: bool,
}
@@ -121,6 +131,7 @@ impl AwsS3Config {
let mut access_key_id = None;
let mut secret_access_key = None;
let mut session_token = None;
+        let mut expire_at = None;

if let Some(credential_provider) = sdk_config.credentials_provider() {
if let Ok(credentials) = PG_BACKEND_TOKIO_RUNTIME
@@ -129,6 +140,7 @@ impl AwsS3Config {
access_key_id = Some(credentials.access_key_id().to_string());
secret_access_key = Some(credentials.secret_access_key().to_string());
session_token = credentials.session_token().map(|t| t.to_string());
+                expire_at = credentials.expiry();
}
}

@@ -141,6 +153,7 @@ impl AwsS3Config {
access_key_id,
secret_access_key,
session_token,
+            expire_at,
endpoint_url,
allow_http,
}
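
For context on where expire_at comes from, here is a condensed, simplified sketch of the credential flow the aws.rs changes rely on: the AWS SDK resolves the configured credential chain (including sts assume_role profiles), and the resolved credentials carry an optional expiry that pg_parquet stores next to the object store. This is not the exact AwsS3Config::load code; error handling is collapsed into Option.

```rust
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;

// Simplified sketch: load the shared AWS config (profile chain, including any
// assume_role profile) and surface the credential expiry. It is None for
// static access keys and Some(..) for temporary tokens such as STS sessions.
async fn credential_expiry() -> Option<std::time::SystemTime> {
    let sdk_config = aws_config::defaults(BehaviorVersion::latest()).load().await;
    let credentials = sdk_config
        .credentials_provider()?
        .provide_credentials()
        .await
        .ok()?;
    credentials.expiry()
}
```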
20 changes: 16 additions & 4 deletions src/object_store/azure.rs
@@ -1,9 +1,13 @@
+use std::sync::Arc;
+
use azure_storage::{ConnectionString, EndpointProtocol};
use home::home_dir;
use ini::Ini;
-use object_store::azure::{AzureConfigKey, MicrosoftAzure, MicrosoftAzureBuilder};
+use object_store::azure::{AzureConfigKey, MicrosoftAzureBuilder};
use url::Url;

+use super::object_store_cache::ObjectStoreWithExpiration;
+
// create_azure_object_store creates a MicrosoftAzure object store with the given container name.
// It is configured by environment variables and azure config files as fallback method.
// We need to read the config files to make the fallback method work since object_store
@@ -16,7 +20,7 @@ use url::Url;
// - AZURE_CONFIG_FILE (env var only, object_store specific)
// - AZURE_STORAGE_ENDPOINT (env var only, object_store specific)
// - AZURE_ALLOW_HTTP (env var only, object_store specific)
-pub(crate) fn create_azure_object_store(uri: &Url) -> MicrosoftAzure {
+pub(crate) fn create_azure_object_store(uri: &Url) -> ObjectStoreWithExpiration {
let container_name = parse_azure_blob_container(uri).unwrap_or_else(|| {
panic!("unsupported azure blob storage uri: {}", uri);
});
@@ -63,10 +67,18 @@ pub(crate) fn create_azure_object_store(uri: &Url) -> MicrosoftAzure {
azure_builder = azure_builder.with_client_secret(client_secret);
}

-    azure_builder.build().unwrap_or_else(|e| panic!("{}", e))
+    let object_store = azure_builder.build().unwrap_or_else(|e| panic!("{}", e));
+
+    // object store handles refreshing bearer token, so we do not need to handle expiry here
+    let expire_at = None;
+
+    ObjectStoreWithExpiration {
+        object_store: Arc::new(object_store),
+        expire_at,
+    }
}

-fn parse_azure_blob_container(uri: &Url) -> Option<String> {
+pub(crate) fn parse_azure_blob_container(uri: &Url) -> Option<String> {
let host = uri.host_str()?;

// az(ure)://{container}/key
17 changes: 14 additions & 3 deletions src/object_store/local_file.rs
@@ -1,10 +1,15 @@
+use std::sync::Arc;
+
use object_store::local::LocalFileSystem;
use url::Url;

-use super::uri_as_string;
+use super::{object_store_cache::ObjectStoreWithExpiration, uri_as_string};

// create_local_file_object_store creates a LocalFileSystem object store with the given path.
-pub(crate) fn create_local_file_object_store(uri: &Url, copy_from: bool) -> LocalFileSystem {
+pub(crate) fn create_local_file_object_store(
+    uri: &Url,
+    copy_from: bool,
+) -> ObjectStoreWithExpiration {
let path = uri_as_string(uri);

if !copy_from {
@@ -17,5 +22,11 @@ pub(crate) fn create_local_file_object_store(uri: &Url, copy_from: bool) -> LocalFileSystem {
.unwrap_or_else(|e| panic!("{}", e));
}

-    LocalFileSystem::new()
+    let object_store = LocalFileSystem::new();
+    let expire_at = None;
+
+    ObjectStoreWithExpiration {
+        object_store: Arc::new(object_store),
+        expire_at,
+    }
}
