Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pool Postgres connections #3043

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = { workspace = true }
[dependencies]
anyhow = { workspace = true }
chrono = "0.4"
deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] }
native-tls = "0.2"
postgres-native-tls = "0.5"
spin-core = { path = "../core" }
Expand Down
115 changes: 74 additions & 41 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,83 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use spin_world::spin::postgres::postgres::{
self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet,
};
use tokio_postgres::types::Type;
use tokio_postgres::{config::SslMode, types::ToSql, Row};
use tokio_postgres::{Client as TokioClient, NoTls, Socket};
use tokio_postgres::{config::SslMode, types::ToSql, NoTls, Row};

const CONNECTION_POOL_SIZE: usize = 64;

#[async_trait]
pub trait Client {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized;
pub trait ClientFactory: Send + Sync {
type Client: Client + Send + Sync + 'static;
fn new() -> Self;
async fn build_client(&mut self, address: &str) -> Result<Self::Client>;
}

pub struct PooledTokioClientFactory {
pools: std::collections::HashMap<String, deadpool_postgres::Pool>,
}

#[async_trait]
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;
fn new() -> Self {
Self {
pools: Default::default(),
}
}
async fn build_client(&mut self, address: &str) -> Result<Self::Client> {
let pool_entry = self.pools.entry(address.to_owned());
let pool = match pool_entry {
std::collections::hash_map::Entry::Occupied(entry) => entry.into_mut(),
std::collections::hash_map::Entry::Vacant(entry) => {
let pool = create_connection_pool(address)
.context("establishing PostgreSQL connection pool")?;
entry.insert(pool)
}
};

Ok(pool.get().await?)
}
}

fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
let config = address
.parse::<tokio_postgres::Config>()
.context("parsing Postgres connection string")?;

tracing::debug!("Build new connection: {}", address);

// TODO: This is slower but safer. Is it the right tradeoff?
// https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
let mgr_config = deadpool_postgres::ManagerConfig {
recycling_method: deadpool_postgres::RecyclingMethod::Clean,
};

let mgr = if config.get_ssl_mode() == SslMode::Disable {
deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
} else {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
deadpool_postgres::Manager::from_config(config, connector, mgr_config)
};

// TODO: what is our max size heuristic? Should this be passed in soe that different
// hosts can manage it according to their needs? Will a plain number suffice for
// sophisticated hosts anyway?
let pool = deadpool_postgres::Pool::builder(mgr)
.max_size(CONNECTION_POOL_SIZE)
.build()
.context("building Postgres connection pool")?;

Ok(pool)
}

#[async_trait]
pub trait Client {
async fn execute(
&self,
statement: String,
Expand All @@ -29,28 +92,7 @@ pub trait Client {
}

#[async_trait]
impl Client for TokioClient {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized,
{
let config = address.parse::<tokio_postgres::Config>()?;

tracing::debug!("Build new connection: {}", address);

if config.get_ssl_mode() == SslMode::Disable {
let (client, connection) = config.connect(NoTls).await?;
spawn_connection(connection);
Ok(client)
} else {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
let (client, connection) = config.connect(connector).await?;
spawn_connection(connection);
Ok(client)
}
}

impl Client for deadpool_postgres::Object {
async fn execute(
&self,
statement: String,
Expand All @@ -67,7 +109,8 @@ impl Client for TokioClient {
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
.collect();

self.execute(&statement, params_refs.as_slice())
self.as_ref()
.execute(&statement, params_refs.as_slice())
.await
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))
}
Expand All @@ -89,6 +132,7 @@ impl Client for TokioClient {
.collect();

let results = self
.as_ref()
.query(&statement, params_refs.as_slice())
.await
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
Expand All @@ -111,17 +155,6 @@ impl Client for TokioClient {
}
}

fn spawn_connection<T>(connection: tokio_postgres::Connection<Socket, T>)
where
T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
{
tokio::spawn(async move {
if let Err(e) = connection.await {
tracing::error!("Postgres connection error: {}", e);
}
});
}

fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Sync>> {
match value {
ParameterValue::Boolean(v) => Ok(Box::new(*v)),
Expand Down
25 changes: 14 additions & 11 deletions crates/factor-outbound-pg/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@ use tracing::field::Empty;
use tracing::instrument;
use tracing::Level;

use crate::client::Client;
use crate::client::{Client, ClientFactory};
use crate::InstanceState;

impl<C: Client> InstanceState<C> {
impl<CF: ClientFactory> InstanceState<CF> {
async fn open_connection<Conn: 'static>(
&mut self,
address: &str,
) -> Result<Resource<Conn>, v3::Error> {
self.connections
.push(
C::build_client(address)
self.client_factory
.write()
.await
.build_client(address)
.await
.map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?,
)
Expand All @@ -30,7 +33,7 @@ impl<C: Client> InstanceState<C> {
async fn get_client<Conn: 'static>(
&mut self,
connection: Resource<Conn>,
) -> Result<&C, v3::Error> {
) -> Result<&CF::Client, v3::Error> {
self.connections
.get(connection.rep())
.ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into()))
Expand Down Expand Up @@ -71,8 +74,8 @@ fn v2_params_to_v3(
params.into_iter().map(|p| p.try_into()).collect()
}

impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnection
for InstanceState<C>
impl<CF: ClientFactory + Send + Sync> spin_world::spin::postgres::postgres::HostConnection
for InstanceState<CF>
{
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
Expand Down Expand Up @@ -122,13 +125,13 @@ impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnecti
}
}

impl<C: Send> v2_types::Host for InstanceState<C> {
impl<CF: ClientFactory + Send> v2_types::Host for InstanceState<CF> {
fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
Ok(error)
}
}

impl<C: Send + Sync + Client> v3::Host for InstanceState<C> {
impl<CF: Send + Sync + ClientFactory> v3::Host for InstanceState<CF> {
fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
Ok(error)
}
Expand All @@ -152,9 +155,9 @@ macro_rules! delegate {
}};
}

impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}
impl<CF: Send + Sync + ClientFactory> v2::Host for InstanceState<CF> {}

impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
impl<CF: Send + Sync + ClientFactory> v2::HostConnection for InstanceState<CF> {
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
spin_factor_outbound_networking::record_address_fields(&address);
Expand Down Expand Up @@ -206,7 +209,7 @@ impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
}
}

impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
impl<CF: Send + Sync + ClientFactory> v1::Host for InstanceState<CF> {
async fn execute(
&mut self,
address: String,
Expand Down
26 changes: 15 additions & 11 deletions crates/factor-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
pub mod client;
mod host;

use client::Client;
use std::sync::Arc;

use client::ClientFactory;
use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor};
use spin_factors::{
anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
};
use tokio_postgres::Client as PgClient;
use tokio::sync::RwLock;

pub struct OutboundPgFactor<C = PgClient> {
_phantom: std::marker::PhantomData<C>,
pub struct OutboundPgFactor<CF = crate::client::PooledTokioClientFactory> {
_phantom: std::marker::PhantomData<CF>,
}

impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
impl<CF: ClientFactory + Send + Sync + 'static> Factor for OutboundPgFactor<CF> {
type RuntimeConfig = ();
type AppState = ();
type InstanceBuilder = InstanceState<C>;
type AppState = Arc<RwLock<CF>>;
type InstanceBuilder = InstanceState<CF>;

fn init<T: Send + 'static>(
&mut self,
Expand All @@ -31,7 +33,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
&self,
_ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
Ok(())
Ok(Arc::new(RwLock::new(CF::new())))
}

fn prepare<T: RuntimeFactors>(
Expand All @@ -43,6 +45,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
.allowed_hosts();
Ok(InstanceState {
allowed_hosts,
client_factory: ctx.app_state().clone(),
connections: Default::default(),
})
}
Expand All @@ -62,9 +65,10 @@ impl<C> OutboundPgFactor<C> {
}
}

pub struct InstanceState<C> {
pub struct InstanceState<CF: ClientFactory> {
allowed_hosts: OutboundAllowedHosts,
connections: spin_resource_table::Table<C>,
client_factory: Arc<RwLock<CF>>,
connections: spin_resource_table::Table<CF::Client>,
}

impl<C: Send + 'static> SelfInstanceBuilder for InstanceState<C> {}
impl<CF: ClientFactory + Send + 'static> SelfInstanceBuilder for InstanceState<CF> {}
Loading
Loading