Skip to content

Commit 4b11e63

Browse files
authored
Merge pull request superfly#373 from superfly/gorbak/track-alive-api-conn
Track active TCP streams in the user facing API
2 parents dbb3fc4 + c512303 commit 4b11e63

File tree

16 files changed

+375
-38
lines changed

16 files changed

+375
-38
lines changed

Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ uhlc = { version = "0.7", features = ["defmt"] }
8585
uuid = { version = "1.3.1", features = ["v4", "serde"] }
8686
webpki = { version = "0.22.0", features = ["std"] }
8787
http = { version = "0.2.9" }
88+
lazy_static = "1.5.0"
8889

8990
[profile.release]
9091
debug = 1

crates/corro-agent/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ uhlc = { workspace = true }
6161
uuid = { workspace = true }
6262
corro-pg = { path = "../corro-pg" }
6363
indexmap = { workspace = true }
64+
pin-project-lite = { workspace = true }
6465
governor.workspace = true
6566

6667
[dev-dependencies]

crates/corro-agent/src/api/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod peer;
22
pub mod public;
3+
pub mod utils;

crates/corro-agent/src/api/public/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::{
55
time::{Duration, Instant},
66
};
77

8+
use crate::api::utils::CountedBody;
89
use antithesis_sdk::assert_sometimes;
910
use axum::{extract::ConnectInfo, response::IntoResponse, Extension};
1011
use bytes::{BufMut, BytesMut};
@@ -18,6 +19,7 @@ use corro_types::{
1819
base::CrsqlDbVersion,
1920
broadcast::Timestamp,
2021
change::{insert_local_changes, InsertChangesInfo, SqliteValue},
22+
persistent_gauge,
2123
schema::{apply_schema, parse_sql},
2224
sqlite::SqlitePoolError,
2325
};
@@ -472,7 +474,9 @@ pub async fn api_v1_queries(
472474
axum::extract::Query(params): axum::extract::Query<TimeoutParams>,
473475
axum::extract::Json(stmt): axum::extract::Json<Statement>,
474476
) -> impl IntoResponse {
475-
let (mut tx, body) = hyper::Body::channel();
477+
let (mut tx, body) = CountedBody::channel(
478+
persistent_gauge!("corro.api.active.streams", "source" => "queries", "protocol" => "http"),
479+
);
476480

477481
counter!("corro.api.queries.count").increment(1);
478482
// TODO: timeout on data send instead of infinitely waiting for channel space.

crates/corro-agent/src/api/public/pubsub.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use std::{collections::HashMap, io::Write, sync::Arc, time::Duration};
22

3+
use crate::api::utils::CountedBody;
34
use axum::{http::StatusCode, response::IntoResponse, Extension};
45
use bytes::{BufMut, Bytes, BytesMut};
56
use compact_str::{format_compact, ToCompactString};
7+
use corro_types::persistent_gauge;
68
use corro_types::updates::Handle;
79
use corro_types::{
810
agent::Agent,
@@ -50,7 +52,7 @@ async fn sub_by_id(
5052
params: SubParams,
5153
bcast_cache: &SharedMatcherBroadcastCache,
5254
tripwire: Tripwire,
53-
) -> hyper::Response<hyper::Body> {
55+
) -> impl IntoResponse {
5456
let matcher_rx = bcast_cache.read().await.get(&id).and_then(|tx| {
5557
subs.get(&id).map(|matcher| {
5658
debug!("found matcher by id {id}");
@@ -99,7 +101,9 @@ async fn sub_by_id(
99101
let query_hash = matcher.hash().to_owned();
100102
tokio::spawn(catch_up_sub(matcher, params, rx, evt_tx));
101103

102-
let (tx, body) = hyper::Body::channel();
104+
let (tx, body) = CountedBody::channel(
105+
persistent_gauge!("corro.api.active.streams", "source" => "subscriptions", "protocol" => "http"),
106+
);
103107

104108
tokio::spawn(forward_bytes_to_body_sender(id, evt_rx, tx, tripwire));
105109

@@ -302,7 +306,7 @@ impl MatcherUpsertError {
302306
}
303307
}
304308

305-
impl From<MatcherUpsertError> for hyper::Response<hyper::Body> {
309+
impl From<MatcherUpsertError> for hyper::Response<CountedBody<hyper::Body>> {
306310
fn from(value: MatcherUpsertError) -> Self {
307311
hyper::Response::builder()
308312
.status(value.status_code())
@@ -676,7 +680,7 @@ pub async fn api_v1_subs(
676680
) -> impl IntoResponse {
677681
let stmt = match expand_sql(&agent, &stmt).await {
678682
Ok(stmt) => stmt,
679-
Err(e) => return hyper::Response::<hyper::Body>::from(e),
683+
Err(e) => return hyper::Response::<CountedBody<hyper::Body>>::from(e),
680684
};
681685

682686
info!("Received subscription request for query: {stmt}");
@@ -695,10 +699,14 @@ pub async fn api_v1_subs(
695699

696700
let (handle, maybe_created) = match upsert_res {
697701
Ok(res) => res,
698-
Err(e) => return hyper::Response::<hyper::Body>::from(MatcherUpsertError::from(e)),
702+
Err(e) => {
703+
return hyper::Response::<CountedBody<hyper::Body>>::from(MatcherUpsertError::from(e))
704+
}
699705
};
700706

701-
let (tx, body) = hyper::Body::channel();
707+
let (tx, body) = CountedBody::channel(
708+
persistent_gauge!("corro.api.active.streams", "source" => "subscriptions", "protocol" => "http"),
709+
);
702710
let (forward_tx, forward_rx) = mpsc::channel(10240);
703711

704712
tokio::spawn(forward_bytes_to_body_sender(
@@ -720,7 +728,7 @@ pub async fn api_v1_subs(
720728
.await
721729
{
722730
Ok(id) => id,
723-
Err(e) => return hyper::Response::<hyper::Body>::from(e),
731+
Err(e) => return hyper::Response::<CountedBody<hyper::Body>>::from(e),
724732
};
725733

726734
hyper::Response::builder()

crates/corro-agent/src/api/public/update.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use compact_str::ToCompactString;
77
use corro_types::{
88
agent::Agent,
99
api::NotifyEvent,
10+
persistent_gauge,
1011
updates::{Handle, UpdateCreated, UpdateHandle, UpdatesManager},
1112
};
1213
use futures::future::poll_fn;
@@ -18,7 +19,7 @@ use tracing::{debug, info, warn};
1819
use tripwire::Tripwire;
1920
use uuid::Uuid;
2021

21-
use crate::api::public::pubsub::MatcherUpsertError;
22+
use crate::api::{public::pubsub::MatcherUpsertError, utils::CountedBody};
2223

2324
pub type UpdateBroadcastCache = HashMap<Uuid, broadcast::Sender<Bytes>>;
2425
pub type SharedUpdateBroadcastCache = Arc<TokioRwLock<UpdateBroadcastCache>>;
@@ -48,16 +49,20 @@ pub async fn api_v1_updates(
4849

4950
let (handle, maybe_created) = match upsert_res {
5051
Ok(res) => res,
51-
Err(e) => return hyper::Response::<hyper::Body>::from(MatcherUpsertError::from(e)),
52+
Err(e) => {
53+
return hyper::Response::<CountedBody<hyper::Body>>::from(MatcherUpsertError::from(e))
54+
}
5255
};
5356

54-
let (tx, body) = hyper::Body::channel();
57+
let (tx, body) = CountedBody::channel(
58+
persistent_gauge!("corro.api.active.streams", "source" => "updates", "protocol" => "http"),
59+
);
5560
// let (forward_tx, forward_rx) = mpsc::channel(10240);
5661

5762
let (update_id, sub_rx) =
5863
match upsert_update(handle.clone(), maybe_created, updates, &mut bcast_write).await {
5964
Ok(id) => id,
60-
Err(e) => return hyper::Response::<hyper::Body>::from(e),
65+
Err(e) => return hyper::Response::<CountedBody<hyper::Body>>::from(e),
6166
};
6267

6368
tokio::spawn(forward_update_bytes_to_body_sender(
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use corro_types::gauge::PersistentGauge;
2+
use hyper::body::{Body, HttpBody, Sender, SizeHint};
3+
use pin_project_lite::pin_project;
4+
use std::pin::Pin;
5+
use std::task::{Context, Poll};
6+
7+
pin_project! {
8+
pub struct CountedBody<B: HttpBody> {
9+
#[pin]
10+
body: B,
11+
gauge: Option<PersistentGauge>,
12+
}
13+
14+
impl<B: HttpBody> PinnedDrop for CountedBody<B> {
15+
fn drop(this: Pin<&mut Self>) {
16+
if let Some(gauge) = &this.gauge {
17+
gauge.decrement(1.0);
18+
}
19+
}
20+
}
21+
}
22+
23+
impl<B: HttpBody> CountedBody<B> {
24+
fn new(body: B, gauge: Option<PersistentGauge>) -> Self {
25+
if let Some(gauge) = &gauge {
26+
gauge.increment(1.0);
27+
}
28+
Self { body, gauge }
29+
}
30+
}
31+
32+
impl CountedBody<Body> {
33+
// Channel bodies need to be counted as they can be long lived
34+
pub fn channel(gauge: PersistentGauge) -> (Sender, Self) {
35+
let (tx, body) = hyper::Body::channel();
36+
(tx, Self::new(body, Some(gauge)))
37+
}
38+
}
39+
40+
impl<B: HttpBody> HttpBody for CountedBody<B> {
41+
type Data = B::Data;
42+
type Error = B::Error;
43+
44+
fn poll_data(
45+
self: Pin<&mut Self>,
46+
cx: &mut Context<'_>,
47+
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
48+
let this = self.project();
49+
this.body.poll_data(cx)
50+
}
51+
52+
fn poll_trailers(
53+
self: Pin<&mut Self>,
54+
cx: &mut Context<'_>,
55+
) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
56+
let this = self.project();
57+
this.body.poll_trailers(cx)
58+
}
59+
60+
fn is_end_stream(&self) -> bool {
61+
self.body.is_end_stream()
62+
}
63+
fn size_hint(&self) -> SizeHint {
64+
self.body.size_hint()
65+
}
66+
}
67+
68+
// If the underlying body can be constructed from some simple type
69+
// then we can implement From<T> for CountedBody<B>
70+
// No need to track their count as they are short lived
71+
pub trait SimpleBody {}
72+
impl SimpleBody for Vec<u8> {}
73+
impl SimpleBody for &'static [u8] {}
74+
impl SimpleBody for String {}
75+
impl SimpleBody for &'static str {}
76+
77+
impl<B, T> From<T> for CountedBody<B>
78+
where
79+
B: HttpBody + From<T>,
80+
T: SimpleBody,
81+
{
82+
fn from(value: T) -> Self {
83+
Self::new(value.into(), None)
84+
}
85+
}

crates/corro-pg/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ sqlparser = { version = "0.39.0" }
3333
chrono = { version = "0.4.31" }
3434
socket2 = { version = "0.5" }
3535
tokio-rustls = "0.24.1"
36+
pin-project-lite = { workspace = true }
3637

3738
[dev-dependencies]
3839
corro-tests = { path = "../corro-tests" }

crates/corro-pg/src/lib.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod sql_state;
2+
pub mod utils;
23
mod vtab;
34

45
use std::{
@@ -19,6 +20,7 @@ use corro_types::{
1920
broadcast::{broadcast_changes, Timestamp},
2021
change::{insert_local_changes, InsertChangesInfo},
2122
config::PgConfig,
23+
persistent_gauge,
2224
schema::{parse_sql, Column, Schema, SchemaError, SqliteType, Table},
2325
sqlite::CrConn,
2426
};
@@ -72,6 +74,7 @@ use tripwire::{Outcome, PreemptibleFutureExt, TimeoutFutureExt, Tripwire};
7274

7375
use crate::{
7476
sql_state::SqlState,
77+
utils::CountedTcpStream,
7578
vtab::{
7679
pg_class::PgClassTable,
7780
pg_database::{PgDatabase, PgDatabaseTable},
@@ -553,30 +556,36 @@ pub async fn start(
553556
let server = TcpListener::bind(pg.bind_addr).await?;
554557
let (tls_acceptor, ssl_required) = setup_tls(pg).await?;
555558
let local_addr = server.local_addr()?;
559+
let conn_gauge = persistent_gauge!("corro.api.active.streams",
560+
"source" => "postgres",
561+
"protocol" => "pg",
562+
"readonly" => readonly.to_string(),
563+
);
556564

557565
tokio::spawn(async move {
558566
loop {
559-
let (mut conn, remote_addr) = match server.accept().preemptible(&mut tripwire).await {
567+
let (tcp_conn, remote_addr) = match server.accept().preemptible(&mut tripwire).await {
560568
Outcome::Completed(res) => res?,
561569
Outcome::Preempted(_) => break,
562570
};
571+
let mut conn = CountedTcpStream::wrap(tcp_conn, conn_gauge.clone());
563572
let tls_acceptor = tls_acceptor.clone();
564573
debug!("Accepted a PostgreSQL connection (from: {remote_addr})");
565574

566-
counter!("corro.api.connection.count", "protocol" => "pg").increment(1);
575+
counter!("corro.api.connection.count", "protocol" => "pg", "readonly" => readonly.to_string()).increment(1);
567576

568577
let agent = agent.clone();
569578
tokio::spawn(async move {
570-
conn.set_nodelay(true)?;
579+
conn.stream.set_nodelay(true)?;
571580
{
572-
let sock = SockRef::from(&conn);
581+
let sock = SockRef::from(&conn.stream);
573582
let ka = TcpKeepalive::new()
574583
.with_time(Duration::from_secs(10))
575584
.with_interval(Duration::from_secs(10))
576585
.with_retries(4);
577586
sock.set_tcp_keepalive(&ka)?;
578587
}
579-
let is_sslrequest = peek_for_sslrequest(&mut conn).await?;
588+
let is_sslrequest = peek_for_sslrequest(&mut conn.stream).await?;
580589

581590
// reject non-ssl connections if ssl is required (client cert auth)
582591
if ssl_required && !is_sslrequest {
@@ -586,7 +595,7 @@ pub async fn start(
586595

587596
let (mut framed, secured) = match (tls_acceptor, is_sslrequest) {
588597
(Some(tls_acceptor), true) => {
589-
conn.write_all(b"S").await?;
598+
conn.stream.write_all(b"S").await?;
590599
let tls_conn = tls_acceptor.accept(conn).await?;
591600
(
592601
Framed::new(
@@ -600,7 +609,7 @@ pub async fn start(
600609
}
601610
(_, is_sslreq) => {
602611
if is_sslreq {
603-
conn.write_all(b"N").await?;
612+
conn.stream.write_all(b"N").await?;
604613
}
605614
(
606615
Framed::new(

0 commit comments

Comments
 (0)