|
| 1 | +use std::io::{self, Read}; |
| 2 | + |
| 3 | +use alloc::collections::VecDeque; |
| 4 | +use bytes::{Buf, Bytes}; |
| 5 | +use super::{LenError, MessageLen}; |
| 6 | + |
| 7 | +/// Collect chunks of bytes until a full message is found |
| 8 | +pub struct MessageBuffer { |
| 9 | + /// Once parsing fails, it's impossible to recover |
| 10 | + fatal_error: bool, |
| 11 | + /// A rope for all incoming data |
| 12 | + chunks: VecDeque<Bytes>, |
| 13 | + /// Chunks are parsed lazily |
| 14 | + chunks_parsed_num: usize, |
| 15 | + /// Cumulative byte size of `chunks[..chunks_parsed_num]` |
| 16 | + chunks_parsed_byte_len: usize, |
| 17 | + /// msgpack parser |
| 18 | + msg_len: MessageLen, |
| 19 | +} |
| 20 | + |
| 21 | +/// Result after buffering chunk of data |
| 22 | +pub enum MaybeMessage { |
| 23 | + /// Found a complete message |
| 24 | + /// |
| 25 | + /// The message is split into pieces |
| 26 | + Message(MessageChunks), |
| 27 | + /// Message not complete yet. Read this many bytes. |
| 28 | + MoreBytes(usize), |
| 29 | +} |
| 30 | + |
| 31 | +/// This keeps individual `Bytes` pieces to avoid reallocating memory |
| 32 | +/// |
| 33 | +/// Use `into_inner` to process them manually, or use `MessageChunks` as `io::Read` |
| 34 | +pub struct MessageChunks(VecDeque<Bytes>); |
| 35 | + |
| 36 | +impl MessageChunks { |
| 37 | + /// Get the underlying `Bytes` |
| 38 | + pub fn into_inner(self) -> VecDeque<Bytes> { |
| 39 | + self.0 |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +/// The exact `IntoIter` type may change in the future |
| 44 | +impl IntoIterator for MessageChunks { |
| 45 | + type IntoIter = <VecDeque<Bytes> as IntoIterator>::IntoIter; |
| 46 | + type Item = Bytes; |
| 47 | + fn into_iter(self) -> Self::IntoIter { |
| 48 | + self.0.into_iter() |
| 49 | + } |
| 50 | +} |
| 51 | + |
| 52 | +impl Read for MessageChunks { |
| 53 | + fn read(&mut self, out_buf: &mut [u8]) -> io::Result<usize> { |
| 54 | + while let Some(bytes) = self.0.get_mut(0) { |
| 55 | + let mut ch = bytes.chunk(); |
| 56 | + if ch.is_empty() { |
| 57 | + self.0.pop_front(); |
| 58 | + continue; |
| 59 | + } |
| 60 | + let read_len = out_buf.len().min(ch.len()); |
| 61 | + out_buf[..read_len].copy_from_slice(&ch[..read_len]); |
| 62 | + if read_len == ch.len() { |
| 63 | + self.0.pop_front(); |
| 64 | + } else { |
| 65 | + ch.advance(read_len); |
| 66 | + } |
| 67 | + return Ok(read_len); |
| 68 | + } |
| 69 | + Ok(0) |
| 70 | + } |
| 71 | + |
| 72 | + fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> { |
| 73 | + let len = self.0.iter().map(|ch| ch.remaining()).sum(); |
| 74 | + buf.try_reserve_exact(len).map_err(|_| io::ErrorKind::OutOfMemory)?; |
| 75 | + for c in self.0.drain(..) { |
| 76 | + buf.extend_from_slice(c.chunk()); |
| 77 | + } |
| 78 | + Ok(len) |
| 79 | + } |
| 80 | +} |
| 81 | + |
| 82 | +impl MessageBuffer { |
| 83 | + #[inline(always)] |
| 84 | + pub fn new() -> Self { |
| 85 | + Self { |
| 86 | + fatal_error: false, |
| 87 | + chunks_parsed_num: 0, |
| 88 | + chunks_parsed_byte_len: 0, |
| 89 | + chunks: VecDeque::new(), |
| 90 | + msg_len: MessageLen::new(), // TODO: limits |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + /// Parse chunks added with `push_bytes` etc., and dequeue chunks of complete msgpack messages |
| 95 | + pub fn poll_messages(&mut self) -> impl Iterator<Item = Result<MaybeMessage, ()>> + '_ { |
| 96 | + std::iter::from_fn(move || { |
| 97 | + while self.chunks_parsed_num < self.chunks.len() && !self.fatal_error { |
| 98 | + let bytes = &mut self.chunks[self.chunks_parsed_num]; |
| 99 | + self.chunks_parsed_num += 1; |
| 100 | + self.chunks_parsed_byte_len += bytes.len(); |
| 101 | + |
| 102 | + match self.msg_len.incremental_len(bytes.as_ref()) { |
| 103 | + Ok(message_len) => { |
| 104 | + self.msg_len.reset(); |
| 105 | + |
| 106 | + let unused_bytes = self.chunks_parsed_byte_len.saturating_sub(message_len); |
| 107 | + let remainder = bytes.split_off(bytes.len() - unused_bytes); |
| 108 | + |
| 109 | + // includes the `bytes` cut |
| 110 | + let message_data = self.chunks.drain(..self.chunks_parsed_num).collect::<VecDeque<_>>(); |
| 111 | + |
| 112 | + self.chunks_parsed_byte_len = 0; |
| 113 | + self.chunks_parsed_num = 0; |
| 114 | + self.chunks.push_front(remainder); |
| 115 | + |
| 116 | + debug_assert!(message_data.iter().all(|b| b.remaining() == b.len())); |
| 117 | + Some(Ok::<MaybeMessage, ()>(MaybeMessage::Message(MessageChunks(message_data)))); |
| 118 | + }, |
| 119 | + Err(LenError::Truncated(new_len)) => { |
| 120 | + if self.chunks_parsed_num >= self.chunks.len() { |
| 121 | + let wants_more = new_len.get().saturating_sub(self.chunks_parsed_byte_len); |
| 122 | + return Some(Ok(MaybeMessage::MoreBytes(wants_more))); |
| 123 | + } |
| 124 | + }, |
| 125 | + Err(LenError::ParseError) => { |
| 126 | + self.fatal_error = true; |
| 127 | + return Some(Err(())); |
| 128 | + }, |
| 129 | + } |
| 130 | + } |
| 131 | + None |
| 132 | + }) |
| 133 | + } |
| 134 | + |
| 135 | + /// Buffer more data |
| 136 | + pub fn push_bytes(&mut self, mut bytes: Bytes) { |
| 137 | + // bytes are stateful, and later `io::Read` will use that |
| 138 | + if bytes.remaining() != bytes.len() { |
| 139 | + bytes = bytes.slice(..); |
| 140 | + } |
| 141 | + self.chunks.push_back(bytes); |
| 142 | + } |
| 143 | + |
| 144 | + /// Buffer more data |
| 145 | + #[inline] |
| 146 | + pub fn push_vec(&mut self, bytes: Vec<u8>) { |
| 147 | + self.push_bytes(bytes.into()); |
| 148 | + } |
| 149 | + |
| 150 | + /// Buffer more data |
| 151 | + #[inline] |
| 152 | + pub fn copy_from_slice(&mut self, bytes: &[u8]) { |
| 153 | + self.push_bytes(Bytes::copy_from_slice(bytes)); |
| 154 | + } |
| 155 | + |
| 156 | + /// Recover buffered data |
| 157 | + pub fn into_bytes(self) -> Vec<Bytes> { |
| 158 | + self.chunks.into() |
| 159 | + } |
| 160 | +} |
| 161 | + |
0 commit comments