From b23b574d62512409deb89b553f2922ddf905ead2 Mon Sep 17 00:00:00 2001 From: tottoto Date: Sat, 4 Jan 2025 10:17:05 +0900 Subject: [PATCH] feat(chore): Add ConnectInfo service --- tonic/src/transport/server/mod.rs | 27 +------ tonic/src/transport/server/service/io.rs | 92 ++++++++++++++++++++--- tonic/src/transport/server/service/mod.rs | 2 +- 3 files changed, 87 insertions(+), 34 deletions(-) diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index cc2758003..086d2b7dd 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -9,6 +9,7 @@ mod tls; #[cfg(unix)] mod unix; +use service::ConnectInfoLayer; use tokio_stream::StreamExt as _; use tracing::{debug, trace}; @@ -984,7 +985,7 @@ struct MakeSvc { impl Service<&ServerIo> for MakeSvc where - IO: Connected, + IO: Connected + 'static, S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -1015,29 +1016,7 @@ where let svc = ServiceBuilder::new() .layer(BoxCloneService::layer()) - .map_request(move |mut request: Request| { - match &conn_info { - tower::util::Either::Left(inner) => { - request.extensions_mut().insert(inner.clone()); - } - tower::util::Either::Right(inner) => { - #[cfg(feature = "_tls-any")] - { - request.extensions_mut().insert(inner.clone()); - request.extensions_mut().insert(inner.get_ref().clone()); - } - - #[cfg(not(feature = "_tls-any"))] - { - // just a type check to make sure we didn't forget to - // insert this into the extensions - let _: &() = inner; - } - } - } - - request - }) + .layer(ConnectInfoLayer::new(conn_info.clone())) .service(Svc { inner: svc, trace_interceptor, diff --git a/tonic/src/transport/server/service/io.rs b/tonic/src/transport/server/service/io.rs index ed43e78a3..152867e75 100644 --- a/tonic/src/transport/server/service/io.rs +++ b/tonic/src/transport/server/service/io.rs @@ -6,6 +6,73 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "_tls-any")] use tokio_rustls::server::TlsStream; +use tower_layer::Layer; +use tower_service::Service; + +#[derive(Debug, Clone)] +pub(crate) struct ConnectInfoLayer { + connect_info: T, +} + +impl ConnectInfoLayer { + pub(crate) fn new(connect_info: T) -> Self { + Self { connect_info } + } +} + +impl Layer for ConnectInfoLayer +where + T: Clone, +{ + type Service = ConnectInfo; + + fn layer(&self, inner: S) -> Self::Service { + ConnectInfo::new(inner, self.connect_info.clone()) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ConnectInfo { + inner: S, + connect_info: T, +} + +impl ConnectInfo { + fn new(inner: S, connect_info: T) -> Self { + Self { + inner, + connect_info, + } + } +} + +impl Service> for ConnectInfo> +where + S: Service>, + IO: Connected, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: http::Request) -> Self::Future { + match self.connect_info.clone() { + ServerIoConnectInfo::Io(inner) => { + req.extensions_mut().insert(inner); + } + #[cfg(feature = "_tls-any")] + ServerIoConnectInfo::TlsIo(inner) => { + req.extensions_mut().insert(inner.get_ref().clone()); + req.extensions_mut().insert(inner); + } + } + self.inner.call(req) + } +} pub(crate) enum ServerIo { Io(IO), @@ -13,14 +80,21 @@ pub(crate) enum ServerIo { TlsIo(Box>), } -use tower::util::Either; - -#[cfg(feature = "_tls-any")] -type ServerIoConnectInfo = - Either<::ConnectInfo, as Connected>::ConnectInfo>; +pub(crate) enum ServerIoConnectInfo { + Io(::ConnectInfo), + #[cfg(feature = "_tls-any")] + TlsIo( as Connected>::ConnectInfo), +} -#[cfg(not(feature = "_tls-any"))] -type ServerIoConnectInfo = Either<::ConnectInfo, ()>; +impl Clone for ServerIoConnectInfo { + fn clone(&self) -> Self { + match self { + Self::Io(io) => Self::Io(io.clone()), + #[cfg(feature = "_tls-any")] + Self::TlsIo(io) => Self::TlsIo(io.clone()), + } + } +} impl ServerIo { pub(in crate::transport) fn new_io(io: IO) -> Self { @@ -37,9 +111,9 @@ impl ServerIo { IO: Connected, { match self { - Self::Io(io) => Either::Left(io.connect_info()), + Self::Io(io) => ServerIoConnectInfo::Io(io.connect_info()), #[cfg(feature = "_tls-any")] - Self::TlsIo(io) => Either::Right(io.connect_info()), + Self::TlsIo(io) => ServerIoConnectInfo::TlsIo(io.connect_info()), } } } diff --git a/tonic/src/transport/server/service/mod.rs b/tonic/src/transport/server/service/mod.rs index b5fce0923..19b3001ac 100644 --- a/tonic/src/transport/server/service/mod.rs +++ b/tonic/src/transport/server/service/mod.rs @@ -1,5 +1,5 @@ mod io; -pub(crate) use self::io::ServerIo; +pub(crate) use self::io::{ConnectInfoLayer, ServerIo}; mod recover_error; pub(crate) use self::recover_error::RecoverError;