Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: fetch extensions #681

Merged
merged 15 commits into from
Mar 12, 2025
1 change: 1 addition & 0 deletions crates/duckdb/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- `Client.search_to_arrow_table` ([#634](https://github.com/stac-utils/stac-rs/pull/634))
- Conditionally disable parsing the WKB ([#635](https://github.com/stac-utils/stac-rs/pull/635))
- `Client.extensions` ([#665](https://github.com/stac-utils/stac-rs/pull/665))
- `Config.install_extensions` ([#681](https://github.com/stac-utils/stac-rs/pull/681))

### Removed

Expand Down
78 changes: 67 additions & 11 deletions crates/duckdb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,14 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
pub struct Client {
connection: Connection,

/// The client's configuration.
pub config: Config,
}

/// Configuration for a client.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub struct Config {
/// Whether to enable the s3 credential chain, which allows s3:// url access.
///
/// True by default.
pub use_s3_credential_chain: bool,

/// Whether to enable hive partitioning.
///
/// False by default.
Expand All @@ -100,6 +96,25 @@ pub struct Config {
///
/// Disable this to enable geopandas reading, for example.
pub convert_wkb: bool,

/// Whether to enable the S3 credential chain, which allows s3:// url access.
///
/// True by default.
pub use_s3_credential_chain: bool,

/// Whether to enable the Azure credential chain, which allows az:// url access.
///
/// True by default.
pub use_azure_credential_chain: bool,

/// Whether to directly install the httpfs extension.
pub use_httpfs: bool,

/// Whether to install extensions when creating a new connection.
pub install_extensions: bool,

/// Use a custom extension repository.
pub custom_extension_repository: Option<String>,
}

/// A SQL query.
Expand Down Expand Up @@ -167,21 +182,45 @@ impl Client {
/// use stac_duckdb::{Client, Config};
///
/// let config = Config {
/// use_s3_credential_chain: true,
/// use_hive_partitioning: true,
/// convert_wkb: true,
/// use_s3_credential_chain: true,
/// use_azure_credential_chain: true,
/// use_httpfs: true,
/// install_extensions: true,
/// custom_extension_repository: None,
/// };
/// let client = Client::with_config(config);
/// ```
pub fn with_config(config: Config) -> Result<Client> {
let connection = Connection::open_in_memory()?;
connection.execute("INSTALL spatial", [])?;
if let Some(ref custom_extension_repository) = config.custom_extension_repository {
connection.execute(
"SET custom_extension_repository = '?'",
[custom_extension_repository],
)?;
}
if config.install_extensions {
connection.execute("INSTALL spatial", [])?;
connection.execute("INSTALL icu", [])?;
}
connection.execute("LOAD spatial", [])?;
connection.execute("INSTALL icu", [])?;
connection.execute("LOAD icu", [])?;
if config.use_httpfs && config.install_extensions {
connection.execute("INSTALL httpfs", [])?;
}
if config.use_s3_credential_chain {
if config.install_extensions {
connection.execute("INSTALL aws", [])?;
}
connection.execute("CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)", [])?;
}
if config.use_azure_credential_chain {
if config.install_extensions {
connection.execute("INSTALL azure", [])?;
}
connection.execute("CREATE SECRET (TYPE azure, PROVIDER CREDENTIAL_CHAIN)", [])?;
}
Ok(Client { connection, config })
}

Expand Down Expand Up @@ -519,15 +558,19 @@ impl Default for Config {
fn default() -> Self {
Config {
use_hive_partitioning: false,
use_s3_credential_chain: true,
convert_wkb: true,
use_s3_credential_chain: true,
use_azure_credential_chain: true,
use_httpfs: true,
install_extensions: true,
custom_extension_repository: None,
}
}
}

#[cfg(test)]
mod tests {
use super::Client;
use super::{Client, Config};
use geo::Geometry;
use rstest::{fixture, rstest};
use stac::{Bbox, Validate};
Expand All @@ -542,6 +585,19 @@ mod tests {
Client::new().unwrap()
}

#[test]
fn no_install() {
let _mutex = MUTEX.lock().unwrap();
let config = Config {
install_extensions: false,
..Default::default()
};
let client = Client::with_config(config).unwrap();
client
.search("data/100-sentinel-2-items.parquet", Search::default())
.unwrap();
}

#[rstest]
fn search_all(client: Client) {
let item_collection = client
Expand Down