From ebbafca8b3fa7a202c517c5ee2d7b81d10f41b9a Mon Sep 17 00:00:00 2001 From: Xylaant Date: Wed, 26 Jun 2024 08:28:23 -0400 Subject: [PATCH] feat: remove cancel task, which was causing a circular reference (#65) Provide user with a weak reference to Requester, when in Server mode, to allow socket to close properly --- rsocket/src/core/client.rs | 31 +- rsocket/src/core/server.rs | 3 +- rsocket/src/transport/mod.rs | 2 +- rsocket/src/transport/socket.rs | 520 +++++++++++++++++++------------- 4 files changed, 326 insertions(+), 230 deletions(-) diff --git a/rsocket/src/core/client.rs b/rsocket/src/core/client.rs index f846bb9..c014ca3 100644 --- a/rsocket/src/core/client.rs +++ b/rsocket/src/core/client.rs @@ -14,14 +14,14 @@ use crate::payload::{Payload, SetupPayload, SetupPayloadBuilder}; use crate::runtime; use crate::spi::{ClientResponder, Flux, RSocket}; use crate::transport::{ - self, Connection, DuplexSocket, FrameSink, FrameStream, Splitter, Transport, + self, Connection, DuplexSocket, FrameSink, FrameStream, ClientRequester, Splitter, Transport, }; use crate::Result; #[derive(Clone)] pub struct Client { closed: Arc, - socket: DuplexSocket, + requester: ClientRequester, closing: mpsc::Sender<()>, } @@ -130,9 +130,9 @@ where let (snd_tx, mut snd_rx) = mpsc::unbounded_channel::(); let cloned_snd_tx = snd_tx.clone(); - let mut socket = DuplexSocket::new(1, snd_tx, splitter).await; + let mut socket = DuplexSocket::new(1, snd_tx, splitter); - let mut cloned_socket = socket.clone(); + let requester = socket.client_requester(); if let Some(f) = self.responder { let responder = f(); @@ -211,10 +211,13 @@ where } }); + + socket.setup(setup).await?; + // process frames runtime::spawn(async move { while let Some(next) = read_rx.recv().await { - if let Err(e) = cloned_socket.dispatch(next, None).await { + if let Err(e) = socket.dispatch(next, None).await { error!("dispatch frame failed: {}", e); break; } @@ -237,16 +240,14 @@ where } }); - socket.setup(setup).await?; - - Ok(Client::new(socket, close_notify, closing)) + Ok(Client::new(requester, close_notify, closing)) } } impl Client { - fn new(socket: DuplexSocket, closed: Arc, closing: mpsc::Sender<()>) -> Client { + fn new(requester: ClientRequester, closed: Arc, closing: mpsc::Sender<()>) -> Client { Client { - socket, + requester, closed, closing, } @@ -260,22 +261,22 @@ impl Client { #[async_trait] impl RSocket for Client { async fn metadata_push(&self, req: Payload) -> Result<()> { - self.socket.metadata_push(req).await + self.requester.metadata_push(req).await } async fn fire_and_forget(&self, req: Payload) -> Result<()> { - self.socket.fire_and_forget(req).await + self.requester.fire_and_forget(req).await } async fn request_response(&self, req: Payload) -> Result> { - self.socket.request_response(req).await + self.requester.request_response(req).await } fn request_stream(&self, req: Payload) -> Flux> { - self.socket.request_stream(req) + self.requester.request_stream(req) } fn request_channel(&self, reqs: Flux>) -> Flux> { - self.socket.request_channel(reqs) + self.requester.request_channel(reqs) } } diff --git a/rsocket/src/core/server.rs b/rsocket/src/core/server.rs index 3b39fb5..5e79f1d 100644 --- a/rsocket/src/core/server.rs +++ b/rsocket/src/core/server.rs @@ -114,7 +114,7 @@ where // Init duplex socket. let (snd_tx, mut snd_rx) = mpsc::unbounded_channel::(); - let mut socket = DuplexSocket::new(0, snd_tx, splitter).await; + let mut socket = DuplexSocket::new(0, snd_tx, splitter); // Begin loop for writing frames. runtime::spawn(async move { @@ -154,7 +154,6 @@ where break; } } - Ok(()) } } diff --git a/rsocket/src/transport/mod.rs b/rsocket/src/transport/mod.rs index 9adddcb..d040d86 100644 --- a/rsocket/src/transport/mod.rs +++ b/rsocket/src/transport/mod.rs @@ -4,5 +4,5 @@ mod socket; mod spi; pub(crate) use fragmentation::{Joiner, Splitter, MIN_MTU}; -pub(crate) use socket::DuplexSocket; +pub(crate) use socket::{ClientRequester,DuplexSocket}; pub use spi::*; diff --git a/rsocket/src/transport/socket.rs b/rsocket/src/transport/socket.rs index fdc796f..de0ec93 100644 --- a/rsocket/src/transport/socket.rs +++ b/rsocket/src/transport/socket.rs @@ -1,6 +1,6 @@ use std::future::Future; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use async_stream::stream; use async_trait::async_trait; @@ -20,19 +20,31 @@ use crate::spi::{Flux, RSocket, ServerResponder}; use crate::utils::EmptyRSocket; use crate::{runtime, Result}; -#[derive(Clone)] -pub(crate) struct DuplexSocket { +struct DuplexSocketInner { seq: StreamID, responder: Responder, tx: mpsc::UnboundedSender, - handlers: Arc>, - canceller: mpsc::Sender, + handlers: DashMap, splitter: Option, - joiners: Arc>, - /// AbortHandles for streams and channels associated by sid + joiners: DashMap, + /// AbortHandles for Response futures/streams abort_handles: Arc>, } +#[derive(Clone)] +pub(crate) struct ClientRequester { + inner: Arc, +} + +#[derive(Clone)] +pub(crate) struct ServerRequester { + inner: Weak, +} + +pub(crate) struct DuplexSocket { + inner: Arc, +} + #[derive(Clone)] struct Responder { inner: Arc>>, @@ -41,36 +53,41 @@ struct Responder { #[derive(Debug)] enum Handler { ReqRR(oneshot::Sender>>), - ResRR(Counter), ReqRS(mpsc::Sender>), ReqRC(mpsc::Sender>), } -impl DuplexSocket { - pub(crate) async fn new( +struct Cancel {} + +impl DuplexSocketInner { + fn new( first_stream_id: u32, tx: mpsc::UnboundedSender, splitter: Option, - ) -> DuplexSocket { + ) -> Self { let (canceller_tx, canceller_rx) = mpsc::channel::(32); - let socket = DuplexSocket { + let this = Self { seq: StreamID::from(first_stream_id), tx, - canceller: canceller_tx, responder: Responder::new(), - handlers: Arc::new(DashMap::new()), - joiners: Arc::new(DashMap::new()), + handlers: DashMap::new(), + joiners: DashMap::new(), splitter, abort_handles: Arc::new(DashMap::new()), }; + this + } +} - let cloned_socket = socket.clone(); - - runtime::spawn(async move { - cloned_socket.loop_canceller(canceller_rx).await; - }); - - socket +impl DuplexSocket { + pub(crate) fn new( + first_stream_id: u32, + tx: mpsc::UnboundedSender, + splitter: Option, + ) -> DuplexSocket { + DuplexSocket { + inner: Arc::new(DuplexSocketInner::new(first_stream_id, tx, splitter)), + } } pub(crate) async fn setup(&mut self, setup: SetupPayload) -> Result<()> { @@ -90,19 +107,12 @@ impl DuplexSocket { if let Some(b) = m { bu = bu.set_metadata(b); } - self.tx.send(bu.build()).map_err(|e| e.into()) + self.inner.tx.send(bu.build()).map_err(|e| e.into()) } #[inline] - async fn register_handler(&self, sid: u32, handler: Handler) { - self.handlers.insert(sid, handler); - } - - #[inline] - async fn loop_canceller(&self, mut rx: mpsc::Receiver) { - while let Some(sid) = rx.recv().await { - self.handlers.remove(&sid); - } + fn register_handler(&self, sid: u32, handler: Handler) { + self.inner.handlers.insert(sid, handler); } pub(crate) async fn dispatch( @@ -132,7 +142,7 @@ impl DuplexSocket { .set_code(error::ERR_REJECT_SETUP) .set_data(Bytes::from(errmsg)) .build(); - if self.tx.send(sending).is_err() { + if self.inner.tx.send(sending).is_err() { error!("Reject setup failed"); } } @@ -198,7 +208,8 @@ impl DuplexSocket { let sid = input.get_stream_id(); if input.get_flag() & Frame::FLAG_FOLLOW != 0 { // TODO: check conflict - self.joiners + self.inner + .joiners .entry(sid) .or_insert_with(Joiner::new) .push(input); @@ -209,7 +220,7 @@ impl DuplexSocket { return Some(input); } - match self.joiners.remove(&sid) { + match self.inner.joiners.remove(&sid) { None => Some(input), Some((_, mut joiner)) => { joiner.push(input); @@ -263,9 +274,9 @@ impl DuplexSocket { #[inline] async fn on_error(&mut self, sid: u32, flag: u16, input: frame::Error) { - self.joiners.remove(&sid); + self.inner.joiners.remove(&sid); // pick handler - if let Some((_, handler)) = self.handlers.remove(&sid) { + if let Some((_, handler)) = self.inner.handlers.remove(&sid) { let desc = input .get_data_utf8() .map(|it| it.to_string()) @@ -277,7 +288,6 @@ impl DuplexSocket { error!("respond with error for REQUEST_RESPONSE failed!"); } } - Handler::ResRR(_) => unreachable!(), Handler::ReqRS(tx) => { if (tx.send(Err(e.into())).await).is_err() { error!("respond with error for REQUEST_STREAM failed!"); @@ -294,11 +304,11 @@ impl DuplexSocket { #[inline] async fn on_cancel(&mut self, sid: u32, _flag: u16) { - if let Some((sid, abort_handle)) = self.abort_handles.remove(&sid) { + if let Some((sid, abort_handle)) = self.inner.abort_handles.remove(&sid) { abort_handle.abort(); } - self.joiners.remove(&sid); - if let Some((_, handler)) = self.handlers.remove(&sid) { + self.inner.joiners.remove(&sid); + if let Some((_, handler)) = self.inner.handlers.remove(&sid) { let e: Result<_> = Err(RSocketError::RequestCancelled("request has been cancelled".into()).into()); match handler { @@ -308,10 +318,6 @@ impl DuplexSocket { error!("notify cancel for REQUEST_RESPONSE failed: sid={}", sid); } } - Handler::ResRR(c) => { - let lefts = c.count_down(); - info!("REQUEST_RESPONSE {} cancelled: lefts={}", sid, lefts); - } Handler::ReqRS(sender) => { info!("REQUEST_STREAM {} cancelled!", sid); } @@ -324,7 +330,7 @@ impl DuplexSocket { #[inline] async fn on_payload(&mut self, sid: u32, flag: u16, input: Payload) { - match self.handlers.entry(sid) { + match self.inner.handlers.entry(sid) { Entry::Occupied(o) => { match o.get() { Handler::ReqRR(_) => match o.remove() { @@ -339,7 +345,6 @@ impl DuplexSocket { } _ => unreachable!(), }, - Handler::ResRR(c) => unreachable!(), Handler::ReqRS(sender) => { if flag & Frame::FLAG_NEXT != 0 { if sender.is_closed() { @@ -379,13 +384,13 @@ impl DuplexSocket { #[inline] fn send_cancel_frame(&self, sid: u32) { let cancel_frame = frame::Cancel::builder(sid, Frame::FLAG_COMPLETE).build(); - if let Err(e) = self.tx.send(cancel_frame) { + if let Err(e) = self.inner.tx.send(cancel_frame) { error!("Sending CANCEL frame failed: sid={}, reason: {}", sid, e); } } pub(crate) async fn bind_responder(&self, responder: Box) { - self.responder.set(responder).await; + self.inner.responder.set(responder).await; } #[inline] @@ -398,12 +403,12 @@ impl DuplexSocket { ) -> Result<()> { match acceptor { None => { - self.responder.set(Box::new(EmptyRSocket)).await; + self.inner.responder.set(Box::new(EmptyRSocket)).await; Ok(()) } - Some(gen) => match gen(setup, Box::new(self.clone())) { + Some(gen) => match gen(setup, Box::new(self.server_requester())) { Ok(it) => { - self.responder.set(it).await; + self.inner.responder.set(it).await; Ok(()) } Err(e) => Err(e), @@ -413,36 +418,36 @@ impl DuplexSocket { #[inline] async fn on_fire_and_forget(&mut self, sid: u32, input: Payload) { - if let Err(e) = self.responder.fire_and_forget(input).await { + // TODO: Spawning a task here in case responer call goes pending, which would hold + // up the entire dispatch loop + if let Err(e) = self.inner.responder.fire_and_forget(input).await { error!("respond fire_and_forget failed: {:?}", e); } } #[inline] async fn on_request_response(&mut self, sid: u32, _flag: u16, input: Payload) { - let responder = self.responder.clone(); - let canceller = self.canceller.clone(); - let mut tx = self.tx.clone(); - let splitter = self.splitter.clone(); - let counter = Counter::new(2); - self.register_handler(sid, Handler::ResRR(counter.clone())) - .await; + let responder = self.inner.responder.clone(); + + let mut tx = self.inner.tx.clone(); + let splitter = self.inner.splitter.clone(); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + let abort_handles = self.inner.abort_handles.clone(); runtime::spawn(async move { - // TODO: use future select - let result = responder.request_response(input).await; - if counter.count_down() == 0 { + + abort_handles.insert(sid, abort_handle); + let result= Abortable::new(responder.request_response(input), abort_registration).await; + abort_handles.remove(&sid); + + // Abort for futures adds an extra result wrapper, so unwrap that and continue + let Ok(result) = result else { // cancelled return; - } - - // async remove canceller - if (canceller.send(sid).await).is_err() { - error!("Send canceller failed: sid={}", sid); - } + }; match result { Ok(Some(res)) => { - Self::try_send_payload( + DuplexSocketInner::try_send_payload( &splitter, &mut tx, sid, @@ -452,7 +457,7 @@ impl DuplexSocket { .await; } Ok(None) => { - Self::try_send_complete(&mut tx, sid, Frame::FLAG_COMPLETE).await; + DuplexSocketInner::try_send_complete(&mut tx, sid, Frame::FLAG_COMPLETE).await; } Err(e) => { let sending = frame::Error::builder(sid, 0) @@ -469,10 +474,10 @@ impl DuplexSocket { #[inline] async fn on_request_stream(&self, sid: u32, flag: u16, input: Payload) { - let responder = self.responder.clone(); - let mut tx = self.tx.clone(); - let splitter = self.splitter.clone(); - let abort_handles = self.abort_handles.clone(); + let responder = self.inner.responder.clone(); + let mut tx = self.inner.tx.clone(); + let splitter = self.inner.splitter.clone(); + let abort_handles = self.inner.abort_handles.clone(); runtime::spawn(async move { let (abort_handle, abort_registration) = AbortHandle::new_pair(); abort_handles.insert(sid, abort_handle); @@ -480,7 +485,14 @@ impl DuplexSocket { while let Some(next) = payloads.next().await { match next { Ok(it) => { - Self::try_send_payload(&splitter, &mut tx, sid, it, Frame::FLAG_NEXT).await; + DuplexSocketInner::try_send_payload( + &splitter, + &mut tx, + sid, + it, + Frame::FLAG_NEXT, + ) + .await; } Err(e) => { let sending = frame::Error::builder(sid, 0) @@ -500,12 +512,12 @@ impl DuplexSocket { #[inline] async fn on_request_channel(&self, sid: u32, flag: u16, first: Payload) { - let responder = self.responder.clone(); - let tx = self.tx.clone(); + let responder = self.inner.responder.clone(); + let tx = self.inner.tx.clone(); let (sender, mut receiver) = mpsc::channel::>(32); sender.send(Ok(first)).await.expect("Send failed!"); - self.register_handler(sid, Handler::ReqRC(sender)).await; - let abort_handles = self.abort_handles.clone(); + self.register_handler(sid, Handler::ReqRC(sender)); + let abort_handles = self.inner.abort_handles.clone(); runtime::spawn(async move { // respond client channel let outputs = responder.request_channel(Box::pin(stream! { @@ -554,7 +566,9 @@ impl DuplexSocket { #[inline] async fn on_metadata_push(&mut self, input: Payload) { - if let Err(e) = self.responder.metadata_push(input).await { + // TODO: Spawning a task here in case responer call goes pending, which would hold + // up the entire dispatch loop + if let Err(e) = self.inner.responder.metadata_push(input).await { error!("response metadata_push failed: {:?}", e); } } @@ -566,139 +580,26 @@ impl DuplexSocket { if let Some(b) = data { sending = sending.set_data(b); } - if let Err(e) = self.tx.send(sending.build()) { + if let Err(e) = self.inner.tx.send(sending.build()) { error!("respond KEEPALIVE failed: {}", e); } } - #[inline] - async fn try_send_channel( - splitter: &Option, - tx: &mut mpsc::UnboundedSender, - sid: u32, - res: Payload, - flag: u16, - ) { - // TODO - match splitter { - Some(sp) => { - let mut cuts: usize = 0; - let mut prev: Option = None; - for next in sp.cut(res, 4) { - if let Some(cur) = prev.take() { - let sending = if cuts == 1 { - frame::RequestChannel::builder(sid, flag | Frame::FLAG_FOLLOW) - .set_all(cur.split()) - .build() - } else { - frame::Payload::builder(sid, Frame::FLAG_FOLLOW) - .set_all(cur.split()) - .build() - }; - // send frame - if let Err(e) = tx.send(sending) { - error!("send request_channel failed: {}", e); - return; - } - } - prev = Some(next); - cuts += 1; - } - - let sending = if cuts == 0 { - frame::RequestChannel::builder(sid, flag).build() - } else if cuts == 1 { - frame::RequestChannel::builder(sid, flag) - .set_all(prev.unwrap().split()) - .build() - } else { - frame::Payload::builder(sid, 0) - .set_all(prev.unwrap().split()) - .build() - }; - // send frame - if let Err(e) = tx.send(sending) { - error!("send request_channel failed: {}", e); - } - } - None => { - let sending = frame::RequestChannel::builder(sid, flag) - .set_all(res.split()) - .build(); - if let Err(e) = tx.send(sending) { - error!("send request_channel failed: {}", e); - } - } + pub(crate) fn client_requester(&self) -> ClientRequester { + ClientRequester { + inner: self.inner.clone(), } } - #[inline] - async fn try_send_complete(tx: &mut mpsc::UnboundedSender, sid: u32, flag: u16) { - let sending = frame::Payload::builder(sid, flag).build(); - if let Err(e) = tx.send(sending) { - error!("respond failed: {}", e); - } - } - - #[inline] - async fn try_send_payload( - splitter: &Option, - tx: &mut mpsc::UnboundedSender, - sid: u32, - res: Payload, - flag: u16, - ) { - match splitter { - Some(sp) => { - let mut cuts: usize = 0; - let mut prev: Option = None; - for next in sp.cut(res, 0) { - if let Some(cur) = prev.take() { - let sending = if cuts == 1 { - frame::Payload::builder(sid, flag | Frame::FLAG_FOLLOW) - .set_all(cur.split()) - .build() - } else { - frame::Payload::builder(sid, Frame::FLAG_FOLLOW) - .set_all(cur.split()) - .build() - }; - // send frame - if let Err(e) = tx.send(sending) { - error!("send payload failed: {}", e); - return; - } - } - prev = Some(next); - cuts += 1; - } - - let sending = if cuts == 0 { - frame::Payload::builder(sid, flag).build() - } else { - frame::Payload::builder(sid, flag) - .set_all(prev.unwrap().split()) - .build() - }; - // send frame - if let Err(e) = tx.send(sending) { - error!("send payload failed: {}", e); - } - } - None => { - let sending = frame::Payload::builder(sid, flag) - .set_all(res.split()) - .build(); - if let Err(e) = tx.send(sending) { - error!("respond failed: {}", e); - } - } + pub(crate) fn server_requester(&self) -> ServerRequester { + ServerRequester { + inner: Arc::downgrade(&self.inner), } } } -#[async_trait] -impl RSocket for DuplexSocket { +// These are the immplementation functions for the requesters below +impl DuplexSocketInner { async fn metadata_push(&self, req: Payload) -> Result<()> { let sid = self.seq.next(); let tx = self.tx.clone(); @@ -767,14 +668,13 @@ impl RSocket for DuplexSocket { async fn request_response(&self, req: Payload) -> Result> { let (tx, rx) = oneshot::channel::>>(); let sid = self.seq.next(); - let handlers = self.handlers.clone(); let sender = self.tx.clone(); - let splitter = self.splitter.clone(); + // Register handler + self.handlers.insert(sid, Handler::ReqRR(tx)); + runtime::spawn(async move { - // register handler - handlers.insert(sid, Handler::ReqRR(tx)); match splitter { Some(sp) => { let mut cuts: usize = 0; @@ -841,10 +741,9 @@ impl RSocket for DuplexSocket { let tx = self.tx.clone(); // register handler let (sender, mut receiver) = mpsc::channel::>(32); - let handlers = self.handlers.clone(); + self.handlers.insert(sid, Handler::ReqRS(sender)); let splitter = self.splitter.clone(); runtime::spawn(async move { - handlers.insert(sid, Handler::ReqRS(sender)); match splitter { Some(sp) => { let mut cuts: usize = 0; @@ -909,12 +808,12 @@ impl RSocket for DuplexSocket { fn request_channel(&self, mut reqs: Flux>) -> Flux> { let sid = self.seq.next(); let mut tx = self.tx.clone(); - // register handler + let (sender, mut receiver) = mpsc::channel::>(32); - let handlers = self.handlers.clone(); + // register handler + self.handlers.insert(sid, Handler::ReqRC(sender)); let splitter = self.splitter.clone(); runtime::spawn(async move { - handlers.insert(sid, Handler::ReqRC(sender)); let mut first = true; while let Some(next) = reqs.next().await { match next { @@ -950,6 +849,131 @@ impl RSocket for DuplexSocket { } }) } + + #[inline] + async fn try_send_channel( + splitter: &Option, + tx: &mut mpsc::UnboundedSender, + sid: u32, + res: Payload, + flag: u16, + ) { + // TODO + match splitter { + Some(sp) => { + let mut cuts: usize = 0; + let mut prev: Option = None; + for next in sp.cut(res, 4) { + if let Some(cur) = prev.take() { + let sending = if cuts == 1 { + frame::RequestChannel::builder(sid, flag | Frame::FLAG_FOLLOW) + .set_all(cur.split()) + .build() + } else { + frame::Payload::builder(sid, Frame::FLAG_FOLLOW) + .set_all(cur.split()) + .build() + }; + // send frame + if let Err(e) = tx.send(sending) { + error!("send request_channel failed: {}", e); + return; + } + } + prev = Some(next); + cuts += 1; + } + + let sending = if cuts == 0 { + frame::RequestChannel::builder(sid, flag).build() + } else if cuts == 1 { + frame::RequestChannel::builder(sid, flag) + .set_all(prev.unwrap().split()) + .build() + } else { + frame::Payload::builder(sid, 0) + .set_all(prev.unwrap().split()) + .build() + }; + // send frame + if let Err(e) = tx.send(sending) { + error!("send request_channel failed: {}", e); + } + } + None => { + let sending = frame::RequestChannel::builder(sid, flag) + .set_all(res.split()) + .build(); + if let Err(e) = tx.send(sending) { + error!("send request_channel failed: {}", e); + } + } + } + } + + #[inline] + async fn try_send_complete(tx: &mut mpsc::UnboundedSender, sid: u32, flag: u16) { + let sending = frame::Payload::builder(sid, flag).build(); + if let Err(e) = tx.send(sending) { + error!("respond failed: {}", e); + } + } + + #[inline] + async fn try_send_payload( + splitter: &Option, + tx: &mut mpsc::UnboundedSender, + sid: u32, + res: Payload, + flag: u16, + ) { + match splitter { + Some(sp) => { + let mut cuts: usize = 0; + let mut prev: Option = None; + for next in sp.cut(res, 0) { + if let Some(cur) = prev.take() { + let sending = if cuts == 1 { + frame::Payload::builder(sid, flag | Frame::FLAG_FOLLOW) + .set_all(cur.split()) + .build() + } else { + frame::Payload::builder(sid, Frame::FLAG_FOLLOW) + .set_all(cur.split()) + .build() + }; + // send frame + if let Err(e) = tx.send(sending) { + error!("send payload failed: {}", e); + return; + } + } + prev = Some(next); + cuts += 1; + } + + let sending = if cuts == 0 { + frame::Payload::builder(sid, flag).build() + } else { + frame::Payload::builder(sid, flag) + .set_all(prev.unwrap().split()) + .build() + }; + // send frame + if let Err(e) = tx.send(sending) { + error!("send payload failed: {}", e); + } + } + None => { + let sending = frame::Payload::builder(sid, flag) + .set_all(res.split()) + .build(); + if let Err(e) = tx.send(sending) { + error!("respond failed: {}", e); + } + } + } + } } impl From> for Responder { @@ -1013,3 +1037,75 @@ impl RSocket for Responder { }) } } + +#[async_trait] +impl RSocket for ClientRequester { + /// Metadata-Push interaction model of RSocket. + async fn metadata_push(&self, req: Payload) -> Result<()> { + self.inner.metadata_push(req).await + } + /// Fire and Forget interaction model of RSocket. + async fn fire_and_forget(&self, req: Payload) -> Result<()> { + self.inner.fire_and_forget(req).await + } + /// Request-Response interaction model of RSocket. + async fn request_response(&self, req: Payload) -> Result> { + self.inner.request_response(req).await + } + /// Request-Stream interaction model of RSocket. + fn request_stream(&self, req: Payload) -> Flux> { + self.inner.request_stream(req) + } + /// Request-Channel interaction model of RSocket. + fn request_channel(&self, reqs: Flux>) -> Flux> { + self.inner.request_channel(reqs) + } +} + +// This implementation, uses the async function trait, as it needs to capture the ugpraded +// Arc, at least until it returns +#[async_trait] +impl RSocket for ServerRequester { + /// Metadata-Push interaction model of RSocket. + async fn metadata_push(&self, req: Payload) -> Result<()> { + match self.inner.upgrade() { + Some(inner) => inner.metadata_push(req).await, + None => Err(RSocketError::ConnectionClosed("close".into()).into()), + } + } + /// Fire and Forget interaction model of RSocket. + async fn fire_and_forget(&self, req: Payload) -> Result<()> { + match self.inner.upgrade() { + Some(inner) => inner.fire_and_forget(req).await, + None => Err(RSocketError::ConnectionClosed("closed".into()).into()), + } + } + /// Request-Response interaction model of RSocket. + async fn request_response(&self, req: Payload) -> Result> { + match self.inner.upgrade() { + Some(inner) => inner.request_response(req).await, + None => Err(RSocketError::ConnectionClosed("closed".into()).into()), + } + } + /// Request-Stream interaction model of RSocket. + fn request_stream(&self, req: Payload) -> Flux> { + use futures::{future, stream}; + + match self.inner.upgrade() { + Some(inner) => inner.request_stream(req), + None => Box::pin(stream::once(future::ready(Err( + RSocketError::ConnectionClosed("closed".into()).into(), + )))), + } + } + /// Request-Channel interaction model of RSocket. + fn request_channel(&self, reqs: Flux>) -> Flux> { + use futures::{future, stream}; + match self.inner.upgrade() { + Some(inner) => inner.request_channel(reqs), + None => Box::pin(stream::once(future::ready(Err( + RSocketError::ConnectionClosed("closed".into()).into(), + )))), + } + } +}