diff --git a/examples/client.rs b/examples/client.rs
index d215b2e..5bdb039 100644
--- a/examples/client.rs
+++ b/examples/client.rs
@@ -3,13 +3,13 @@
 //!
 //! You can try this example by running:
 //!
-//! cargo run --example server
+//! cargo run --example server
 //!
 //! And then start client in another terminal by running:
 //!
-//! cargo run --example client
+//! cargo run --example client
 
-use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma, RdmaBuilder};
+use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, MrAccess, RCStream, Rdma, RdmaBuilder};
 use std::{
     alloc::Layout,
     env,
@@ -118,6 +118,35 @@ async fn request_then_write_cas(rdma: &Rdma) -> io::Result<()> {
     Ok(())
 }
 
+async fn rcstream_send(stream: &mut RCStream) -> io::Result<()> {
+    for i in 0..10 {
+        // alloc 8 bytes of local memory
+        let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
+        // write data into the lmr
+        let _num = lmr.as_mut_slice().write(&[i as u8; 8])?;
+        // send the data in the mr to the remote end
+        stream.send_lmr(lmr).await?;
+        println!("stream send datagram {}", i);
+    }
+    Ok(())
+}
+
+async fn rcstream_recv(stream: &mut RCStream) -> io::Result<()> {
+    for i in 0..10 {
+        // receive data from the remote end
+        let mut lmr_vec = stream.receive_lmr(8).await?;
+        println!("stream receive datagram {}", i);
+        // check the length of the received data
+        assert!(lmr_vec.len() == 1);
+        let lmr = lmr_vec.pop().unwrap();
+        assert!(lmr.length() == 8);
+        let buff = *(lmr.as_slice());
+        // check the data
+        assert_eq!(buff, [i as u8; 8]);
+    }
+    Ok(())
+}
+
 #[tokio::main]
 async fn main() {
     println!("client start");
@@ -153,5 +182,8 @@ async fn main() {
         request_then_write_with_imm(&rdma).await.unwrap();
         request_then_write_cas(&rdma).await.unwrap();
     }
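+    // Converting the `Rdma` into an `RCStream` consumes the connection, so
+    // all traffic from this point on goes through the stream API.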
+    let mut stream: RCStream = rdma.into();
+    rcstream_send(&mut stream).await.unwrap();
+    rcstream_recv(&mut stream).await.unwrap();
     println!("client done");
 }
diff --git a/examples/server.rs b/examples/server.rs
index 713cb06..f91caa5 100644
--- a/examples/server.rs
+++ b/examples/server.rs
@@ -3,14 +3,15 @@
 //!
 //! You can try this example by running:
 //!
-//! cargo run --example server
+//! cargo run --example server
 //!
 //! And start client in another terminal by running:
 //!
-//! cargo run --example client
+//! cargo run --example client
 
-use async_rdma::{LocalMrReadAccess, Rdma, RdmaBuilder};
+use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, MrAccess, RCStream, Rdma, RdmaBuilder};
 use clippy_utilities::Cast;
+use std::io::Write;
 use std::{alloc::Layout, env, io, process::exit};
 
 /// receive data from client
@@ -90,6 +91,35 @@ async fn receive_mr_after_being_written_by_cas(rdma: &Rdma) -> io::Result<()> {
     Ok(())
 }
 
+async fn rcstream_send(stream: &mut RCStream) -> io::Result<()> {
+    for i in 0..10 {
+        // alloc 8 bytes of local memory
+        let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
+        // write data into the lmr
+        let _num = lmr.as_mut_slice().write(&[i as u8; 8])?;
+        // send the data in the mr to the remote end
+        stream.send_lmr(lmr).await?;
+        println!("stream send datagram {}", i);
+    }
+    Ok(())
+}
+
+async fn rcstream_recv(stream: &mut RCStream) -> io::Result<()> {
+    for i in 0..10 {
+        // receive data from the remote end
+        let mut lmr_vec = stream.receive_lmr(8).await?;
+        println!("stream receive datagram {}", i);
+        // check the length of the received data
+        assert!(lmr_vec.len() == 1);
+        let lmr = lmr_vec.pop().unwrap();
+        assert!(lmr.length() == 8);
+        let buff = *(lmr.as_slice());
+        // check the data
+        assert_eq!(buff, [i as u8; 8]);
+    }
+    Ok(())
+}
+
 #[tokio::main]
 async fn main() {
     println!("server start");
@@ -129,5 +159,8 @@ async fn main() {
         .unwrap();
     receive_mr_after_being_written_by_cas(&rdma).await.unwrap();
     }
+    let mut stream: RCStream = rdma.into();
+    rcstream_recv(&mut stream).await.unwrap();
+    rcstream_send(&mut stream).await.unwrap();
     println!("server done");
 }
diff --git a/src/agent.rs b/src/agent.rs
index 6670eaf..1b494f7 100644
--- a/src/agent.rs
+++ b/src/agent.rs
@@ -1,7 +1,7 @@
 use crate::context::Context;
 use crate::hashmap_extension::HashMapExtension;
 use crate::ibv_event_listener::IbvEventListener;
-use crate::queue_pair::MAX_RECV_WR;
+use crate::queue_pair::{QPSendOwn, QueuePairOp, QueuePairOpsInflight, MAX_RECV_WR};
 use crate::rmr_manager::RemoteMrManager;
 use crate::RemoteMrReadAccess;
 use crate::{
@@ -249,6 +249,23 @@ impl Agent {
         Ok(())
     }
 
+    /// Send the content in the `lms` to the other side, without waiting for completion
+    pub(crate) async fn submit_send_data(
+        &self,
+        lms: Vec<LocalMr>,
+        imm: Option<u32>,
+    ) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
+        let lm_len = lms.iter().map(|lm| lm.length()).sum::<usize>();
+        assert!(lm_len <= self.max_msg_len());
+        let kind = RequestKind::SendData(SendDataRequest { len: lm_len });
+        let req_submitted = self
+            .inner
+            // the total length was checked against `max_msg_len` above
+            .submit_send_request_append_data(kind, lms, imm)
+            .await?;
+        Ok(req_submitted)
+    }
+
     /// Receive content sent from the other side and stored in the `LocalMr`
     pub(crate) async fn receive_data(&self) -> io::Result<(LocalMr, Option<u32>)> {
         let (lmr, len, imm) = self
@@ -689,6 +706,45 @@ impl AgentInner {
         }
     }
 
+    /// Submit a send request with data appended
+    async fn submit_send_request_append_data(
+        &self,
+        kind: RequestKind,
+        data: Vec<LocalMr>,
+        imm: Option<u32>,
+    ) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
+        let data_len: usize = data.iter().map(|l| l.length()).sum();
+        assert!(data_len <= self.max_sr_data_len);
+        let (tx, rx) = channel(2);
+        let req_id = self
+            .response_waits
+            .lock()
+            .insert_until_success(tx, AgentRequestId::new);
+        let req = Request {
+            request_id: req_id,
+            kind,
+        };
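+        // SAFETY: `Layout::from_size_align_unchecked` requires a nonzero
+        // power-of-two alignment and a size that does not overflow `isize`
+        // when rounded up to it; alignment 1 trivially satisfies both here.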
+        let mut header_buf = self
+            .allocator
+            // alignment 1 is always correct
+            .alloc_zeroed_default(unsafe {
+                &Layout::from_size_align_unchecked(*REQUEST_HEADER_MAX_LEN, 1)
+            })?;
+        // SAFETY: the mr is writable here without cancel safety issues
+        let cursor = Cursor::new(unsafe { header_buf.as_mut_slice_unchecked() });
+        let message = Message::Request(req);
+        // FIXME: serialize update
+        bincode::serialize_into(cursor, &message)
+            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
+        // prepend the serialized request header to the user data
+        let mut lmrs = vec![header_buf];
+        lmrs.extend(data);
+        let inflight = self.qp.submit_send_sge(lmrs, imm).await?;
+        Ok(RequestSubmitted::new(inflight, rx))
+    }
+
     /// Send a response to the other side
     async fn send_response(&self, response: Response) -> io::Result<()> {
         // SAFETY: ?
@@ -850,7 +906,7 @@ struct AllocMRRequest {
 
 /// Response to the alloc MR request
 #[derive(Debug, Serialize, Deserialize)]
-struct AllocMRResponse {
+pub(crate) struct AllocMRResponse {
     /// The token to access the MR
     token: MrToken,
 }
@@ -864,7 +920,7 @@ struct ReleaseMRRequest {
 
 /// Response to the release MR request
 #[derive(Debug, Serialize, Deserialize)]
-struct ReleaseMRResponse {
+pub(crate) struct ReleaseMRResponse {
     /// The status of the operation
     status: usize,
 }
@@ -887,7 +943,7 @@ struct SendMRRequest {
 
 /// Response to the request of sending MR
 #[derive(Debug, Serialize, Deserialize)]
-struct SendMRResponse {
+pub(crate) struct SendMRResponse {
     /// The kinds of Response to the request of sending MR
     kind: SendMRResponseKind,
 }
@@ -911,9 +967,9 @@ struct SendDataRequest {
 
 /// Response to the request of sending data
 #[derive(Debug, Serialize, Deserialize)]
-struct SendDataResponse {
+pub(crate) struct SendDataResponse {
     /// response status
-    status: usize,
+    pub(crate) status: usize,
 }
 
 /// Request type enumeration
@@ -941,7 +997,7 @@ struct Request {
 
 /// Response type enumeration
 #[derive(Serialize, Deserialize, Debug)]
 #[allow(variant_size_differences)]
-enum ResponseKind {
+pub(crate) enum ResponseKind {
     /// Allocate MR
     AllocMR(AllocMRResponse),
     /// Release MR
@@ -969,3 +1025,36 @@ enum Message {
     /// Response
     Response(Response),
 }
+
+/// Queue pair operation submitted to the WQ, waiting for the WC & the response
+#[derive(Debug)]
+pub(crate) struct RequestSubmitted<Op: QueuePairOp> {
+    /// the operation of the request
+    inflight: QueuePairOpsInflight<Op>,
+    /// receiver for the response of the request
+    rx: Receiver<io::Result<ResponseKind>>,
+}
+
+impl<Op: QueuePairOp> RequestSubmitted<Op> {
+    /// Create a new `RequestSubmitted`
+    fn new(
+        inflight: QueuePairOpsInflight<Op>,
+        rx: Receiver<io::Result<ResponseKind>>,
+    ) -> Self {
+        Self { inflight, rx }
+    }
+
+    /// Wait for the local work completion and then the remote response
+    pub(crate) async fn response(mut self) -> io::Result<ResponseKind> {
+        let _ = self.inflight.result().await?;
+        match tokio::time::timeout(RESPONSE_TIMEOUT, self.rx.recv()).await {
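+            // `rx.recv()` yields `Option<io::Result<ResponseKind>>`: `None`
+            // means the sending half was dropped (the agent went away) before
+            // a response arrived.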
+            Ok(resp) => {
+                resp.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
+            }
+            Err(_) => Err(io::Error::new(
+                io::ErrorKind::TimedOut,
+                "Timeout for waiting for a response.",
+            )),
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 1983409..b521030 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -166,7 +166,7 @@ mod work_request;
 
 use access::flags_into_ibv_access;
 pub use access::AccessFlag;
-use agent::{Agent, MAX_MSG_LEN};
+use agent::{Agent, RequestSubmitted, ResponseKind, MAX_MSG_LEN};
 use clippy_utilities::Cast;
 use completion_queue::{DEFAULT_CQ_SIZE, DEFAULT_MAX_CQE};
 use context::Context;
@@ -184,7 +184,7 @@ pub use mr_allocator::MRManageStrategy;
 use mr_allocator::MrAllocator;
 use protection_domain::ProtectionDomain;
 use queue_pair::{
-    QueuePair, QueuePairInitAttrBuilder, RQAttr, RQAttrBuilder, SQAttr, SQAttrBuilder,
+    QPSendOwn, QueuePair, QueuePairInitAttrBuilder, RQAttr, RQAttrBuilder, SQAttr, SQAttrBuilder,
 };
 use rdma_sys::ibv_access_flags;
 #[cfg(feature = "cm")]
@@ -195,7 +195,7 @@ use rdma_sys::{
 use rmr_manager::DEFAULT_RMR_TIMEOUT;
 #[cfg(feature = "cm")]
 use std::ptr::null_mut;
-use std::{alloc::Layout, fmt::Debug, io, sync::Arc, time::Duration};
+use std::{alloc::Layout, collections::BTreeMap, fmt::Debug, io, sync::Arc, time::Duration};
 use tokio::{
     io::{AsyncReadExt, AsyncWriteExt},
     net::{TcpListener, TcpStream, ToSocketAddrs},
@@ -1672,6 +1672,39 @@ impl Rdma {
             .await
     }
 
+    /// Submit a send of the `lm`, without waiting for its completion
+    ///
+    /// Used with `receive`.
+    #[allow(unused)]
+    #[inline]
+    async fn submit_send(
+        &self,
+        lm: Vec<LocalMr>,
+    ) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
+        self.agent
+            .as_ref()
+            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
+            .submit_send_data(lm, None)
+            .await
+    }
+
+    /// Submit a send of the `lm` with an immediate value, without waiting for its completion
+    ///
+    /// Used with `receive_with_imm`.
+    #[inline]
+    async fn submit_send_with_imm(
+        &self,
+        lm: Vec<LocalMr>,
+        imm: u32,
+    ) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
+        debug!("submit send seq_id {:?}", imm);
+        self.agent
+            .as_ref()
+            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
+            .submit_send_data(lm, Some(imm))
+            .await
+    }
+
     /// A 64 bits value in a remote mr being read, compared with `old_value` and if they are equal,
     /// the `new_value` is being written to the remote mr in an atomic way.
     ///
@@ -4057,6 +4090,160 @@ impl Rdma {
     }
 }
 
+/// The wrapper of an RDMA RC connection, with convenient methods for ordered
+/// sends and receives.
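+///
+/// # Example (sketch)
+///
+/// A minimal sketch of the intended usage, assuming an `Rdma` connection has
+/// already been established as in `examples/client.rs`:
+///
+/// ```ignore
+/// let mut stream: RCStream = rdma.into();
+/// let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
+/// let _num = lmr.as_mut_slice().write(&[1_u8; 8])?;
+/// stream.send_lmr(lmr).await?;
+/// let received = stream.receive_lmr(8).await?;
+/// ```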
+///
+/// TODO: how to close the stream?
+#[derive(Debug)]
+pub struct RCStream {
+    /// inner rdma transport
+    inner: Rdma,
+    /// current send sequence number
+    send_seq: u32,
+    /// current recv sequence number
+    recv_seq: u32,
+    /// current lmr being read from
+    read_buf: Option<LocalMr>,
+    /// received lmr buffer; segments may be delivered out of order, so they
+    /// are buffered here by sequence number
+    recv_buf: BTreeMap<u32, LocalMr>,
+    /// sender of inflight requests; another task waits for the completion of
+    /// each request
+    inflights_tx: mpsc::Sender<RequestSubmitted<QPSendOwn<LocalMr>>>,
+}
+
+impl RCStream {
+    /// Create a new `RCStream` from an `Rdma`
+    #[must_use]
+    pub fn new(inner: Rdma) -> Self {
+        let (inflights_tx, mut inflights_rx) = mpsc::channel(1024);
+        let _ = tokio::spawn(async move {
+            while let Some(inflight) = inflights_rx.recv().await {
+                let _ = Self::handle_send_wc(inflight).await;
+            }
+        });
+        RCStream {
+            inner,
+            send_seq: 0,
+            recv_seq: 0,
+            read_buf: None,
+            recv_buf: BTreeMap::new(),
+            inflights_tx,
+        }
+    }
+
+    /// Handle the send wc and check the `ResponseKind`
+    async fn handle_send_wc(inflight: RequestSubmitted<QPSendOwn<LocalMr>>) -> io::Result<()> {
+        let resp = inflight.response().await?;
+        if let ResponseKind::SendData(send_data_resp) = resp {
+            if send_data_resp.status > 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::Other,
+                    format!(
+                        "send data failed, response status is {}",
+                        send_data_resp.status
+                    ),
+                ));
+            }
+        } else {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                format!(
+                    "send data failed due to unexpected response type {:?}",
+                    resp
+                ),
+            ));
+        }
+        Ok(())
+    }
+
+    /// Send a `LocalMr` no larger than `max_message_length` to the remote peer.
+    pub async fn send_lmr_segment(&mut self, lmr_segment: LocalMr) -> io::Result<()> {
+        let inflight = self
+            .inner
+            .submit_send_with_imm(vec![lmr_segment], self.send_seq)
+            .await?;
+        self.send_seq = self.send_seq.wrapping_add(1);
+        match self.inflights_tx.send(inflight).await {
+            Ok(_) => Ok(()),
+            Err(_) => Err(io::Error::new(
+                io::ErrorKind::Other,
+                "inflight queue is closed",
+            )),
+        }
+    }
+
+    /// Send a `LocalMr` to the remote peer, splitting it into segments if necessary.
+    pub async fn send_lmr(&mut self, mut lmr: LocalMr) -> io::Result<()> {
+        while lmr.length() > self.inner.clone_attr.agent_attr.max_message_length {
+            // split off one segment of at most `max_message_length` bytes
+            let lmr_segment = lmr.split_to(self.inner.clone_attr.agent_attr.max_message_length);
+            self.send_lmr_segment(lmr_segment.unwrap()).await?;
+        }
+        self.send_lmr_segment(lmr).await?;
+        Ok(())
+    }
+
+    /// Wait for the next data lmr in sequence order
+    async fn read_next_lmr(&mut self) -> io::Result<()> {
+        loop {
+            match self.recv_buf.remove(&self.recv_seq) {
+                Some(lmr) => {
+                    self.read_buf = Some(lmr);
+                    self.recv_seq = self.recv_seq.wrapping_add(1);
+                    break;
+                }
+                None => {
+                    // the next lmr is not in recv_buf yet, so wait for the
+                    // next receive and check again
+                    let (lmr, seq) = self.inner.receive_with_imm().await?;
+                    debug!("receive seq: {:?}", seq);
+                    if self.recv_buf.insert(seq.unwrap(), lmr).is_some() {
+                        return Err(io::Error::new(
+                            io::ErrorKind::Other,
+                            "recv_buf can not contain duplicated seq",
+                        ));
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Receive `LocalMr`s whose total size equals the given size
+    ///
+    /// TODO: how to read EOF?
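+    ///
+    /// A buffered segment may be larger than what is still needed; in that
+    /// case it is split and the remainder stays in `read_buf` for the next
+    /// call.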
+    pub async fn receive_lmr(&mut self, mut fill_size: usize) -> io::Result<Vec<LocalMr>> {
+        let mut ret_lmr = vec![];
+        while fill_size > 0 {
+            if let Some(mut lmr) = self.read_buf.take() {
+                if lmr.length() > fill_size {
+                    let read_part = lmr.split_to(fill_size);
+                    self.read_buf = Some(lmr);
+                    ret_lmr.push(read_part.unwrap());
+                    fill_size = 0;
+                } else {
+                    fill_size -= lmr.length();
+                    ret_lmr.push(lmr);
+                }
+            } else {
+                self.read_next_lmr().await?;
+            }
+        }
+        Ok(ret_lmr)
+    }
+
+    /// Allocate a local memory region
+    ///
+    /// The parameter `layout` can be obtained by `Layout::new::<T>()`.
+    pub fn alloc_local_mr(&mut self, layout: Layout) -> io::Result<LocalMr> {
+        self.inner.alloc_local_mr(layout)
+    }
+}
+
+impl From<Rdma> for RCStream {
+    /// Create a `RCStream` from an `Rdma`
+    fn from(rdma: Rdma) -> Self {
+        Self::new(rdma)
+    }
+}
+
 /// Rdma Listener is the wrapper of a `TcpListener`, which is used to
 /// build the rdma queue pair.
 #[derive(Debug)]
diff --git a/src/memory_region/local.rs b/src/memory_region/local.rs
index 3b53e6a..a89229d 100644
--- a/src/memory_region/local.rs
+++ b/src/memory_region/local.rs
@@ -515,6 +515,57 @@ impl LocalMr {
         }
     }
 
+    /// Splits the local mr into two at the given index.
+    ///
+    /// Afterwards `self` contains elements `[at, len)`, and the returned
+    /// `LocalMr` contains elements `[0, at)`.
+    ///
+    /// This is an `O(1)` operation that just increases the reference count
+    /// and sets a few indices.
+    ///
+    /// Returns `None` if `at > len`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// #[tokio::test]
+    /// async fn test_lmr_split() -> io::Result<()> {
+    ///     let rdma = RdmaBuilder::default()
+    ///         .set_port_num(1)
+    ///         .set_gid_index(1)
+    ///         .build()?;
+    ///     let layout = Layout::new::<[u8; 4096]>();
+    ///     let mut lmr = rdma.alloc_local_mr(layout)?;
+    ///     let start_addr = lmr.addr();
+    ///     let lmr_half = lmr.split_to(2048);
+    ///     assert!(lmr_half.is_some());
+    ///     let lmr_half = lmr_half.unwrap();
+    ///     assert_eq!(lmr_half.length(), 2048);
+    ///     assert_eq!(lmr_half.addr(), start_addr);
+    ///     let lmr_overbound = lmr.split_to(2049);
+    ///     assert!(lmr_overbound.is_none());
+    ///     Ok(())
+    /// }
+    /// ```
+    #[inline]
+    pub fn split_to(&mut self, at: usize) -> Option<Self> {
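+        // No `unsafe` is needed here: both halves share `inner` via reference
+        // counting, and the bounds check below refuses to split past the end.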
+        if at > self.length() {
+            None
+        } else {
+            let old_addr = self.addr;
+            self.addr = self.addr.wrapping_add(at);
+            self.len = self.len.wrapping_sub(at);
+            Some(Self {
+                inner: self.inner.clone(),
+                addr: old_addr,
+                len: at,
+            })
+        }
+    }
+
     /// Take the ownership and return an unchecked sub local mr from self
     ///
     /// # Safety
diff --git a/src/queue_pair.rs b/src/queue_pair.rs
index 8eabeb3..c7b4a6c 100644
--- a/src/queue_pair.rs
+++ b/src/queue_pair.rs
@@ -900,6 +900,20 @@ impl QueuePair {
         QueuePairOps::new(Arc::<Self>::clone(self), send, get_lmr_inners(lms))
     }
 
+    /// Submit a send request of local memory regions, without waiting for its completion
+    pub(crate) fn submit_send_sge<LR>(
+        self: &Arc<Self>,
+        lms: Vec<LR>,
+        imm: Option<u32>,
+    ) -> QueuePairOpsSubmit<QPSendOwn<LR>>
+    where
+        LR: LocalMrReadAccess + Unpin,
+    {
+        let inners = get_lmr_inners(&lms.iter().collect::<Vec<_>>());
+        let send = QPSendOwn::new(lms, imm);
+        QueuePairOpsSubmit::new(Arc::<Self>::clone(self), send, inners)
+    }
+
     /// Send raw data
     #[cfg(feature = "raw")]
     pub(crate) async fn send_sge_raw<'a, LR>(
@@ -1205,6 +1219,59 @@ where
             .map(|sz| debug!("post size: {sz}, mr len: {}", self.len))
     }
 }
+
+/// Queue pair send operation that owns its local memory regions
+#[derive(Debug)]
+pub(crate) struct QPSendOwn<LR>
+where
+    LR: LocalMrReadAccess,
+{
+    /// local memory regions
+    lms: Vec<LR>,
+    /// length of data to send
+    len: usize,
+    /// Optionally, an immediate 4 byte value may be transmitted with the data buffer.
+    imm: Option<u32>,
+}
+
+impl<LR> QPSendOwn<LR>
+where
+    LR: LocalMrReadAccess,
+{
+    /// Create a new send operation from `lms`
+    fn new(lms: Vec<LR>, imm: Option<u32>) -> Self {
+        Self {
+            len: lms.iter().map(|lm| lm.length()).sum(),
+            lms,
+            imm,
+        }
+    }
+}
+
+impl<LR> QueuePairOp for QPSendOwn<LR>
+where
+    LR: LocalMrReadAccess,
+{
+    type Output = ();
+
+    fn submit(&self, qp: &QueuePair, wr_id: WorkRequestId) -> io::Result<()> {
+        let lr_ref = self.lms.iter().collect::<Vec<_>>();
+        qp.submit_send(&lr_ref, wr_id, self.imm)
+    }
+
+    fn should_resubmit(&self, e: &io::Error) -> bool {
+        matches!(e.kind(), io::ErrorKind::OutOfMemory)
+    }
+
+    fn result(&self, wc: WorkCompletion) -> Result<Self::Output, WCError> {
+        wc.result()
+            .map(|sz| debug!("post size: {sz}, mr len: {}", self.len))
+    }
+}
+
 /// Queue pair receive operation
 #[derive(Debug)]
 pub(crate) struct QPRecv<'lm, LW>
 where
@@ -1244,6 +1311,33 @@ where
     }
 }
 
+/// Queue pair operation submitted to the WQ, waiting for the WC
+#[derive(Debug)]
+pub(crate) struct QueuePairOpsInflight<Op: QueuePairOp> {
+    /// the operation
+    op: Op,
+    /// the work completion receiver
+    wc_rx: tokio::sync::mpsc::Receiver<WorkCompletion>,
+}
+
+impl<Op: QueuePairOp> QueuePairOpsInflight<Op> {
+    /// Create a new inflight queue pair operation
+    fn new(op: Op, wc_rx: tokio::sync::mpsc::Receiver<WorkCompletion>) -> Self {
+        Self { op, wc_rx }
+    }
+
+    /// Get the operation result
+    pub(crate) async fn result(&mut self) -> io::Result<Op::Output> {
+        match self.wc_rx.recv().await {
+            Some(wc) => self.op.result(wc).map_err(Into::into),
+            None => Err(io::Error::new(
+                io::ErrorKind::Other,
+                "WC receiver unexpectedly closed",
+            )),
+        }
+    }
+}
+
 /// Queue pair operation state
 #[derive(Debug)]
 enum QueuePairOpsState {
@@ -1332,6 +1426,75 @@ impl<Op: QueuePairOp + Unpin> Future for QueuePairOps<Op> {
     }
 }
 
+/// Queue pair operation wrapper that returns right after `ibv_post_send`
+#[derive(Debug)]
+pub(crate) struct QueuePairOpsSubmit<Op: QueuePairOp> {
+    /// the internal queue pair
+    qp: Arc<QueuePair>,
+    /// operation state
+    state: QueuePairOpsState,
+    /// the operation
+    op: Option<Op>,
+}
+
+impl<Op: QueuePairOp> QueuePairOpsSubmit<Op> {
+    /// Create a new `QueuePairOpsSubmit` wrapper
+    fn new(qp: Arc<QueuePair>, op: Op, inners: LmrInners) -> Self {
+        Self {
+            qp,
+            state: QueuePairOpsState::Init(inners),
+            op: Some(op),
+        }
+    }
+}
+
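+// Polling drives the same state machine as `QueuePairOps`: `Init` registers
+// the mrs with the CQ event listener to get a work request id and a WC
+// receiver, `Submit` posts the send (backing off via `PendingToResubmit` on
+// `OutOfMemory`), and instead of waiting for the completion the future
+// resolves to a `QueuePairOpsInflight` that holds the WC receiver.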
+impl<Op: QueuePairOp + Unpin> Future for QueuePairOpsSubmit<Op> {
+    type Output = io::Result<QueuePairOpsInflight<Op>>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        let s = self.get_mut();
+        match s.state {
+            QueuePairOpsState::Init(ref inners) => {
+                let (wr_id, recv) = s.qp.cq_event_listener.register_for_write(inners)?;
+                s.state = QueuePairOpsState::Submit(wr_id, Some(recv));
+                Pin::new(s).poll(cx)
+            }
+            QueuePairOpsState::Submit(wr_id, ref mut recv) => {
+                let op = s.op.as_mut().unwrap();
+                if let Err(e) = op.submit(&s.qp, wr_id) {
+                    if op.should_resubmit(&e) {
+                        let sleep = Box::pin(sleep(RESUBMIT_DELAY));
+                        s.state = QueuePairOpsState::PendingToResubmit(sleep, wr_id, recv.take());
+                    } else {
+                        tracing::error!("failed to submit the operation");
+                        // TODO: deregister wr_id
+                        return Poll::Ready(Err(e));
+                    }
+                } else {
+                    match recv.take().ok_or_else(|| {
+                        io::Error::new(io::ErrorKind::Other, "Bug in queue pair op poll")
+                    }) {
+                        Ok(recv) => {
+                            return Poll::Ready(Ok(QueuePairOpsInflight::new(
+                                s.op.take().unwrap(),
+                                recv,
+                            )));
+                        }
+                        Err(e) => return Poll::Ready(Err(e)),
+                    }
+                }
+                Pin::new(s).poll(cx)
+            }
+            QueuePairOpsState::PendingToResubmit(ref mut sleep, wr_id, ref mut recv) => {
+                ready!(sleep.poll_unpin(cx));
+                s.state = QueuePairOpsState::Submit(wr_id, recv.take());
+                Pin::new(s).poll(cx)
+            }
+            QueuePairOpsState::Submitted(_) => unreachable!(),
+        }
+    }
+}
+
 /// Builders to attributes helper
 pub(crate) fn builders_into_attrs(
     mut recv_attr_builder: RQAttrBuilder,