From 815df41fe03be9a067480366cc9ef0bef0f81dee Mon Sep 17 00:00:00 2001 From: tottoto Date: Thu, 11 Jul 2024 00:52:54 +0900 Subject: [PATCH] chore(server): Vendor hyper-util graceful shutdown feature --- tonic/src/transport/server/graceful.rs | 250 +++++++++++++++++++++++++ tonic/src/transport/server/mod.rs | 53 ++---- 2 files changed, 269 insertions(+), 34 deletions(-) create mode 100644 tonic/src/transport/server/graceful.rs diff --git a/tonic/src/transport/server/graceful.rs b/tonic/src/transport/server/graceful.rs new file mode 100644 index 000000000..9c63017e1 --- /dev/null +++ b/tonic/src/transport/server/graceful.rs @@ -0,0 +1,250 @@ +// From https://github.com/hyperium/hyper-util/blob/7afb1ed5337c0689d7341e09d31578f1fcffc8af/src/server/graceful.rs, +// implements Clone for GracefulShutdown. + +use std::{ + fmt::{self, Debug}, + future::Future, + pin::Pin, + task::{self, Poll}, +}; + +use pin_project::pin_project; +use tokio::sync::watch; + +/// A graceful shutdown utility +#[derive(Clone)] +pub(super) struct GracefulShutdown { + tx: watch::Sender<()>, +} + +impl GracefulShutdown { + /// Create a new graceful shutdown helper. + pub(super) fn new() -> Self { + let (tx, _) = watch::channel(()); + Self { tx } + } + + /// Wrap a future for graceful shutdown watching. + pub(super) fn watch(&self, conn: C) -> impl Future { + let mut rx = self.tx.subscribe(); + GracefulConnectionFuture::new(conn, async move { + let _ = rx.changed().await; + // hold onto the rx until the watched future is completed + rx + }) + } + + /// Signal shutdown for all watched connections. + /// + /// This returns a `Future` which will complete once all watched + /// connections have shutdown. + pub(super) async fn shutdown(self) { + let Self { tx } = self; + + // signal all the watched futures about the change + let _ = tx.send(()); + // and then wait for all of them to complete + tx.closed().await; + } +} + +impl Debug for GracefulShutdown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulShutdown").finish() + } +} + +impl Default for GracefulShutdown { + fn default() -> Self { + Self::new() + } +} + +#[pin_project] +struct GracefulConnectionFuture { + #[pin] + conn: C, + #[pin] + cancel: F, + #[pin] + // If cancelled, this is held until the inner conn is done. + cancelled_guard: Option, +} + +impl GracefulConnectionFuture { + fn new(conn: C, cancel: F) -> Self { + Self { + conn, + cancel, + cancelled_guard: None, + } + } +} + +impl Debug for GracefulConnectionFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulConnectionFuture").finish() + } +} + +impl Future for GracefulConnectionFuture +where + C: GracefulConnection, + F: Future, +{ + type Output = C::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let mut this = self.project(); + if this.cancelled_guard.is_none() { + if let Poll::Ready(guard) = this.cancel.poll(cx) { + this.cancelled_guard.set(Some(guard)); + this.conn.as_mut().graceful_shutdown(); + } + } + this.conn.poll(cx) + } +} + +/// An internal utility trait as an umbrella target for all (hyper) connection +/// types that the [`GracefulShutdown`] can watch. +pub(super) trait GracefulConnection: + Future> + private::Sealed +{ + /// The error type returned by the connection when used as a future. + type Error; + + /// Start a graceful shutdown process for this connection. + fn graceful_shutdown(self: Pin<&mut Self>); +} + +impl GracefulConnection for hyper::server::conn::http1::Connection +where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http1::Connection::graceful_shutdown(self); + } +} + +impl GracefulConnection for hyper::server::conn::http2::Connection +where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http2::Connection::graceful_shutdown(self); + } +} + +impl<'a, I, B, S, E> GracefulConnection for hyper_util::server::conn::auto::Connection<'a, I, S, E> +where + S: hyper::service::Service, Response = http::Response>, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = Box; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper_util::server::conn::auto::Connection::graceful_shutdown(self); + } +} + +impl<'a, I, B, S, E> GracefulConnection + for hyper_util::server::conn::auto::UpgradeableConnection<'a, I, S, E> +where + S: hyper::service::Service, Response = http::Response>, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = Box; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper_util::server::conn::auto::UpgradeableConnection::graceful_shutdown(self); + } +} + +mod private { + pub(crate) trait Sealed {} + + impl Sealed for hyper::server::conn::http1::Connection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + { + } + + impl Sealed for hyper::server::conn::http1::UpgradeableConnection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + { + } + + impl Sealed for hyper::server::conn::http2::Connection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } + + impl<'a, I, B, S, E> Sealed for hyper_util::server::conn::auto::Connection<'a, I, S, E> + where + S: hyper::service::Service< + http::Request, + Response = http::Response, + >, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } + + impl<'a, I, B, S, E> Sealed for hyper_util::server::conn::auto::UpgradeableConnection<'a, I, S, E> + where + S: hyper::service::Service< + http::Request, + Response = http::Response, + >, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } +} diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 9d687a9a3..470c53aa8 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -1,6 +1,7 @@ //! Server implementation and builder. mod conn; +mod graceful; mod incoming; mod service; #[cfg(feature = "tls")] @@ -36,7 +37,10 @@ pub use incoming::TcpIncoming; #[cfg(feature = "tls")] use crate::transport::Error; -use self::service::{RecoverError, ServerIo}; +use self::{ + graceful::GracefulShutdown, + service::{RecoverError, ServerIo}, +}; use super::service::GrpcTimeout; use crate::body::{boxed, BoxBody}; use crate::server::NamedService; @@ -561,10 +565,7 @@ impl Server { builder }; - let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); - let signal_tx = Arc::new(signal_tx); - - let graceful = signal.is_some(); + let graceful = signal.is_some().then(GracefulShutdown::new); let mut sig = pin!(Fuse { inner: signal }); let mut incoming = pin!(incoming); @@ -600,21 +601,13 @@ impl Server { let hyper_io = TokioIo::new(io); let hyper_svc = TowerToHyperService::new(req_svc.map_request(|req: Request| req.map(boxed))); - serve_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone())); + serve_connection(hyper_io, hyper_svc, server.clone(), graceful.clone()); } } } - if graceful { - let _ = signal_tx.send(()); - drop(signal_rx); - trace!( - "waiting for {} connections to close", - signal_tx.receiver_count() - ); - - // Wait for all connections to close - signal_tx.closed().await; + if let Some(graceful) = graceful { + graceful.shutdown().await; } Ok(()) @@ -627,7 +620,7 @@ fn serve_connection( hyper_io: IO, hyper_svc: S, builder: ConnectionBuilder, - mut watcher: Option>, + graceful: Option, ) where B: http_body::Body + Send + 'static, B::Data: Send, @@ -640,28 +633,20 @@ fn serve_connection( { tokio::spawn(async move { { - let mut sig = pin!(Fuse { - inner: watcher.as_mut().map(|w| w.changed()), - }); + let conn = builder.serve_connection(hyper_io, hyper_svc); - let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc)); + let result = if let Some(graceful) = graceful { + let conn = graceful.watch(conn); + conn.await + } else { + conn.await + }; - loop { - tokio::select! { - rv = &mut conn => { - if let Err(err) = rv { - debug!("failed serving connection: {:#}", err); - } - break; - }, - _ = &mut sig => { - conn.as_mut().graceful_shutdown(); - } - } + if let Err(err) = result { + debug!("failed serving connection: {:#}", err); } } - drop(watcher); trace!("connection closed"); }); }