Skip to content

Commit

Permalink
chore(server): Add ConnectInfo service (#2118)
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto authored Jan 4, 2025
1 parent 5e9a5bc commit ce043fd
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 35 deletions.
28 changes: 3 additions & 25 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub use incoming::TcpIncoming;
#[cfg(feature = "_tls-any")]
use crate::transport::Error;

use self::service::{RecoverError, ServerIo};
use self::service::{ConnectInfoLayer, RecoverError, ServerIo};
use super::service::GrpcTimeout;
use crate::body::Body;
use crate::server::NamedService;
Expand Down Expand Up @@ -984,7 +984,7 @@ struct MakeSvc<S, IO> {

impl<S, ResBody, IO> Service<&ServerIo<IO>> for MakeSvc<S, IO>
where
IO: Connected,
IO: Connected + 'static,
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<crate::BoxError> + Send,
Expand Down Expand Up @@ -1015,29 +1015,7 @@ where

let svc = ServiceBuilder::new()
.layer(BoxCloneService::layer())
.map_request(move |mut request: Request<Body>| {
match &conn_info {
tower::util::Either::Left(inner) => {
request.extensions_mut().insert(inner.clone());
}
tower::util::Either::Right(inner) => {
#[cfg(feature = "_tls-any")]
{
request.extensions_mut().insert(inner.clone());
request.extensions_mut().insert(inner.get_ref().clone());
}

#[cfg(not(feature = "_tls-any"))]
{
// just a type check to make sure we didn't forget to
// insert this into the extensions
let _: &() = inner;
}
}
}

request
})
.layer(ConnectInfoLayer::new(conn_info.clone()))
.service(Svc {
inner: svc,
trace_interceptor,
Expand Down
92 changes: 83 additions & 9 deletions tonic/src/transport/server/service/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,95 @@ use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "_tls-any")]
use tokio_rustls::server::TlsStream;
use tower_layer::Layer;
use tower_service::Service;

#[derive(Debug, Clone)]
pub(crate) struct ConnectInfoLayer<T> {
connect_info: T,
}

impl<T> ConnectInfoLayer<T> {
pub(crate) fn new(connect_info: T) -> Self {
Self { connect_info }
}
}

impl<S, T> Layer<S> for ConnectInfoLayer<T>
where
T: Clone,
{
type Service = ConnectInfo<S, T>;

fn layer(&self, inner: S) -> Self::Service {
ConnectInfo::new(inner, self.connect_info.clone())
}
}

#[derive(Debug, Clone)]
pub(crate) struct ConnectInfo<S, T> {
inner: S,
connect_info: T,
}

impl<S, T> ConnectInfo<S, T> {
fn new(inner: S, connect_info: T) -> Self {
Self {
inner,
connect_info,
}
}
}

impl<S, IO, ReqBody> Service<http::Request<ReqBody>> for ConnectInfo<S, ServerIoConnectInfo<IO>>
where
S: Service<http::Request<ReqBody>>,
IO: Connected,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
match self.connect_info.clone() {
ServerIoConnectInfo::Io(inner) => {
req.extensions_mut().insert(inner);
}
#[cfg(feature = "_tls-any")]
ServerIoConnectInfo::TlsIo(inner) => {
req.extensions_mut().insert(inner.get_ref().clone());
req.extensions_mut().insert(inner);
}
}
self.inner.call(req)
}
}

pub(crate) enum ServerIo<IO> {
Io(IO),
#[cfg(feature = "_tls-any")]
TlsIo(Box<TlsStream<IO>>),
}

use tower::util::Either;

#[cfg(feature = "_tls-any")]
type ServerIoConnectInfo<IO> =
Either<<IO as Connected>::ConnectInfo, <TlsStream<IO> as Connected>::ConnectInfo>;
pub(crate) enum ServerIoConnectInfo<IO: Connected> {
Io(<IO as Connected>::ConnectInfo),
#[cfg(feature = "_tls-any")]
TlsIo(<TlsStream<IO> as Connected>::ConnectInfo),
}

#[cfg(not(feature = "_tls-any"))]
type ServerIoConnectInfo<IO> = Either<<IO as Connected>::ConnectInfo, ()>;
impl<IO: Connected> Clone for ServerIoConnectInfo<IO> {
fn clone(&self) -> Self {
match self {
Self::Io(io) => Self::Io(io.clone()),
#[cfg(feature = "_tls-any")]
Self::TlsIo(io) => Self::TlsIo(io.clone()),
}
}
}

impl<IO> ServerIo<IO> {
pub(in crate::transport) fn new_io(io: IO) -> Self {
Expand All @@ -37,9 +111,9 @@ impl<IO> ServerIo<IO> {
IO: Connected,
{
match self {
Self::Io(io) => Either::Left(io.connect_info()),
Self::Io(io) => ServerIoConnectInfo::Io(io.connect_info()),
#[cfg(feature = "_tls-any")]
Self::TlsIo(io) => Either::Right(io.connect_info()),
Self::TlsIo(io) => ServerIoConnectInfo::TlsIo(io.connect_info()),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/transport/server/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod io;
pub(crate) use self::io::ServerIo;
pub(crate) use self::io::{ConnectInfoLayer, ServerIo};

mod recover_error;
pub(crate) use self::recover_error::RecoverError;
Expand Down

0 comments on commit ce043fd

Please sign in to comment.