Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: syncronize ws ping/pong messages #1437

Closed
wants to merge 16 commits into from
8 changes: 6 additions & 2 deletions examples/examples/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
use std::net::SocketAddr;

use jsonrpsee::core::client::ClientT;
use jsonrpsee::server::{RpcServiceBuilder, Server};
use jsonrpsee::server::{PingConfig, RpcServiceBuilder, Server};
use jsonrpsee::ws_client::WsClientBuilder;
use jsonrpsee::{rpc_params, RpcModule};
use tracing_subscriber::util::SubscriberInitExt;
Expand All @@ -51,7 +51,11 @@ async fn main() -> anyhow::Result<()> {

async fn run_server() -> anyhow::Result<SocketAddr> {
let rpc_middleware = RpcServiceBuilder::new().rpc_logger(1024);
let server = Server::builder().set_rpc_middleware(rpc_middleware).build("127.0.0.1:0").await?;
let server = Server::builder()
.enable_ws_ping(PingConfig::new())
.set_rpc_middleware(rpc_middleware)
.build("127.0.0.1:0")
.await?;
let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _, _| "lo")?;
let addr = server.local_addr()?;
Expand Down
55 changes: 52 additions & 3 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

use std::error::Error as StdError;
use std::future::Future;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
use std::pin::Pin;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
Expand Down Expand Up @@ -240,6 +240,8 @@ pub struct TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
pub(crate) conn_id: Arc<AtomicU32>,
/// Connection guard.
pub(crate) conn_guard: ConnectionGuard,
/// IP address.
pub(crate) ip_addr: Option<IpAddr>,
}

/// Configuration for batch request handling.
Expand Down Expand Up @@ -461,6 +463,7 @@ pub struct Builder<HttpMiddleware, RpcMiddleware> {
server_cfg: ServerConfig,
rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
http_middleware: tower::ServiceBuilder<HttpMiddleware>,
ip_addr: Option<IpAddr>,
}

impl Default for Builder<Identity, Identity> {
Expand All @@ -469,6 +472,7 @@ impl Default for Builder<Identity, Identity> {
server_cfg: ServerConfig::default(),
rpc_middleware: RpcServiceBuilder::new(),
http_middleware: tower::ServiceBuilder::new(),
ip_addr: None,
}
}
}
Expand All @@ -478,6 +482,12 @@ impl Builder<Identity, Identity> {
pub fn new() -> Self {
Self::default()
}

/// Set the address of the remote peer.
pub fn set_ip_addr(mut self, ip_addr: IpAddr) -> Self {
self.ip_addr = Some(ip_addr);
self
}
}

impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
Expand All @@ -497,6 +507,7 @@ impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddl
conn_id,
conn_guard: self.conn_guard,
server_cfg: self.server_cfg,
ip_addr: self.ip_addr,
},
on_session_close: None,
};
Expand Down Expand Up @@ -526,6 +537,7 @@ impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddl
http_middleware: self.http_middleware,
conn_id: self.conn_id,
conn_guard: self.conn_guard,
ip_addr: self.ip_addr,
}
}

Expand All @@ -540,6 +552,19 @@ impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddl
http_middleware,
conn_id: self.conn_id,
conn_guard: self.conn_guard,
ip_addr: self.ip_addr,
}
}

/// Set the address of the remote peer.
pub fn set_ip_addr(self, ip_addr: IpAddr) -> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
TowerServiceBuilder {
server_cfg: self.server_cfg,
rpc_middleware: self.rpc_middleware,
http_middleware: self.http_middleware,
conn_id: self.conn_id,
conn_guard: self.conn_guard,
ip_addr: Some(ip_addr),
}
}
}
Expand Down Expand Up @@ -637,7 +662,12 @@ impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
/// let builder = ServerBuilder::default().set_rpc_middleware(m);
/// ```
pub fn set_rpc_middleware<T>(self, rpc_middleware: RpcServiceBuilder<T>) -> Builder<HttpMiddleware, T> {
Builder { server_cfg: self.server_cfg, rpc_middleware, http_middleware: self.http_middleware }
Builder {
server_cfg: self.server_cfg,
rpc_middleware,
http_middleware: self.http_middleware,
ip_addr: self.ip_addr,
}
}

/// Configure a custom [`tokio::runtime::Handle`] to run the server on.
Expand Down Expand Up @@ -723,7 +753,12 @@ impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
/// }
/// ```
pub fn set_http_middleware<T>(self, http_middleware: tower::ServiceBuilder<T>) -> Builder<T, RpcMiddleware> {
Builder { server_cfg: self.server_cfg, http_middleware, rpc_middleware: self.rpc_middleware }
Builder {
server_cfg: self.server_cfg,
http_middleware,
rpc_middleware: self.rpc_middleware,
ip_addr: self.ip_addr,
}
}

/// Configure `TCP_NODELAY` on the socket to the supplied value `nodelay`.
Expand Down Expand Up @@ -859,6 +894,7 @@ impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
http_middleware: self.http_middleware,
conn_id: Arc::new(AtomicU32::new(0)),
conn_guard: ConnectionGuard::new(max_conns),
ip_addr: self.ip_addr,
}
}

Expand Down Expand Up @@ -940,6 +976,8 @@ struct ServiceData {
conn_guard: ConnectionGuard,
/// ServerConfig
server_cfg: ServerConfig,
/// IP address.
ip_addr: Option<IpAddr>,
}

/// jsonrpsee tower service
Expand Down Expand Up @@ -1049,6 +1087,15 @@ where

request.extensions_mut().insert::<ConnectionId>(conn.conn_id.into());

if let Some(ip_addr) = self.inner.ip_addr {
// Only insert the remote address if it's not already set.
// We expect servers deployed behind a reverse proxy to set the remote address
// themselves otherwise the remote address will be the address of the reverse proxy.
if request.extensions().get::<IpAddr>().is_none() {
request.extensions_mut().insert(ip_addr);
}
}

let is_upgrade_request = is_upgrade_request(&request);

if self.inner.server_cfg.enable_ws && is_upgrade_request {
Expand Down Expand Up @@ -1190,6 +1237,7 @@ where
stop_handle,
drop_on_completion,
methods,
remote_addr,
..
} = params;

Expand All @@ -1205,6 +1253,7 @@ where
stop_handle: stop_handle.clone(),
conn_id,
conn_guard: conn_guard.clone(),
ip_addr: Some(remote_addr.ip()),
},
rpc_middleware,
on_session_close: None,
Expand Down
Loading
Loading