diff --git a/Cargo.toml b/Cargo.toml index d0db8e6a9..ae2bceebb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,3 +28,6 @@ members = [ "tests/default_stubs", ] resolver = "2" + +[patch.crates-io] +hyper-util = { git = "https://github.com/hyperium/hyper-util.git", rev = "refs/pull/136/head" } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index d41f4ab8a..1248a45ee 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -37,7 +37,7 @@ server = [ "dep:async-stream", "dep:h2", "dep:hyper", "hyper?/server", - "dep:hyper-util", "hyper-util?/service", "hyper-util?/server-auto", + "dep:hyper-util", "hyper-util?/service", "hyper-util?/server-auto", "hyper-util?/server-graceful", "dep:socket2", "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time", "tokio-stream/net", diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 31700ee11..e1fd29895 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -16,6 +16,7 @@ use crate::service::Routes; pub use conn::{Connected, TcpConnectInfo}; use hyper_util::{ rt::{TokioExecutor, TokioIo, TokioTimer}, + server::graceful::GracefulShutdown, service::TowerToHyperService, }; #[cfg(feature = "tls")] @@ -562,10 +563,7 @@ impl Server { builder }; - let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); - let signal_tx = Arc::new(signal_tx); - - let graceful = signal.is_some(); + let graceful = signal.is_some().then(GracefulShutdown::new); let mut sig = pin!(Fuse { inner: signal }); let mut incoming = pin!(incoming); @@ -602,21 +600,13 @@ impl Server { let hyper_io = TokioIo::new(io); let hyper_svc = TowerToHyperService::new(req_svc); - serve_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone())); + serve_connection(hyper_io, hyper_svc, server.clone(), graceful.clone()); } } } - if graceful { - let _ = signal_tx.send(()); - drop(signal_rx); - trace!( - "waiting for {} connections to close", - signal_tx.receiver_count() - ); - - // Wait for all connections to close - signal_tx.closed().await; + if let Some(graceful) = graceful { + graceful.shutdown().await; } Ok(()) @@ -629,7 +619,7 @@ fn serve_connection( hyper_io: IO, hyper_svc: S, builder: ConnectionBuilder, - mut watcher: Option>, + graceful: Option, ) where IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, S: HyperService, Response = Response> + Clone + Send + 'static, @@ -638,28 +628,20 @@ fn serve_connection( { tokio::spawn(async move { { - let mut sig = pin!(Fuse { - inner: watcher.as_mut().map(|w| w.changed()), - }); + let conn = builder.serve_connection(hyper_io, hyper_svc); - let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc)); + let result = if let Some(graceful) = graceful { + let conn = graceful.watch(conn); + conn.await + } else { + conn.await + }; - loop { - tokio::select! { - rv = &mut conn => { - if let Err(err) = rv { - debug!("failed serving connection: {:#}", err); - } - break; - }, - _ = &mut sig => { - conn.as_mut().graceful_shutdown(); - } - } + if let Err(err) = result { + debug!("failed serving connection: {:#}", err); } } - drop(watcher); trace!("connection closed"); }); }