Skip to content

Commit 9db4efd

Browse files
committed
Improved handling of the output buffer in read_all_timeout() function.
1 parent 140c8e9 commit 9db4efd

File tree

5 files changed

+67
-40
lines changed

5 files changed

+67
-40
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
[package]
55
name = "mtcp-rs"
6-
version = "0.1.8"
6+
version = "0.1.9"
77
edition = "2021"
88
license-file = "LICENSE"
99
description = "Provides a “blocking” implementation of TcpListener and TcpStream with proper timeout and cancellation support."

src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl Debug for TcpError {
4343
Self::TimedOut => write!(f, "TcpError::TimedOut"),
4444
Self::Incomplete => write!(f, "TcpError::Incomplete"),
4545
Self::TooBig => write!(f, "TcpError::TooBig"),
46-
Self::Failed(error) => write!(f, "TcpError::Failed({:?})", error),
46+
Self::Failed(error) => write!(f, "TcpError::Failed({error})"),
4747
}
4848
}
4949
}
@@ -55,7 +55,7 @@ impl Display for TcpError {
5555
Self::TimedOut => write!(f, "The TCP socket operation timed out!"),
5656
Self::Incomplete => write!(f, "The TCP socket operation is incomplete!"),
5757
Self::TooBig => write!(f, "The TCP socket operation aborted, data is too big!"),
58-
Self::Failed(error) => write!(f, "{}", error),
58+
Self::Failed(error) => write!(f, "{error}"),
5959
}
6060
}
6161
}

src/stream.rs

Lines changed: 14 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use mio::net::TcpStream as MioTcpStream;
1414

1515
use log::warn;
1616

17-
use crate::utilities::Timeout;
17+
use crate::utilities::{Timeout, BufferManager};
1818
use crate::{TcpConnection, TcpManager, TcpError};
1919
use crate::manager::TcpPollContext;
2020

@@ -284,26 +284,25 @@ impl TcpStream {
284284
F: Fn(&[u8]) -> bool,
285285
{
286286
let chunk_size = chunk_size.unwrap_or_else(|| NonZeroUsize::new(4096).unwrap());
287-
let maximum_length = maximum_length.map(|value| NonZeroUsize::new(round_up(value.get(), chunk_size)).unwrap());
288-
let mut valid_length = buffer.len();
287+
if maximum_length.map_or(false, |value| value < chunk_size) {
288+
panic!("maximum_length must be greater than or equal to chunk_size!")
289+
}
290+
291+
let mut buffer = BufferManager::from(buffer, maximum_length);
289292

290293
loop {
291-
adjust_buffer(buffer, valid_length, chunk_size, maximum_length)?;
292-
let done = match self.read_timeout(&mut buffer[valid_length..], timeout) {
293-
Ok(0) => Some(Err(TcpError::Incomplete)),
294+
let spare = buffer.alloc_spare_buffer(chunk_size);
295+
match self.read_timeout(spare, timeout) {
296+
Ok(0) => return Err(TcpError::Incomplete),
294297
Ok(count) => {
295-
valid_length += count;
296-
match fn_complete(&buffer[..valid_length]) {
297-
true => Some(Ok(())),
298-
false => None,
298+
buffer.commit(count).map_err(|_err| TcpError::TooBig)?;
299+
match fn_complete(buffer.valid_data()) {
300+
true => return Ok(()),
301+
false => {},
299302
}
300303
},
301-
Err(error) => Some(Err(error)),
304+
Err(error) => return Err(error),
302305
};
303-
if let Some(result) = done {
304-
buffer.truncate(valid_length);
305-
return result;
306-
}
307306
}
308307
}
309308

@@ -445,25 +444,3 @@ fn into_io_result<T>(result: Result<T, TcpError>) -> IoResult<T> {
445444
Err(error) => Err(error.into()),
446445
}
447446
}
448-
449-
fn adjust_buffer(buffer: &mut Vec<u8>, length: usize, chunk_size: NonZeroUsize, maximum_length: Option<NonZeroUsize>) -> Result<(), TcpError> {
450-
let mut capacity = round_up(length, chunk_size);
451-
while capacity <= length {
452-
capacity = capacity.checked_add(chunk_size.get()).expect("Numerical overflow!")
453-
}
454-
if capacity > buffer.len() {
455-
if maximum_length.map_or(false, |max_len| capacity > max_len.get()) {
456-
return Err(TcpError::TooBig);
457-
}
458-
buffer.resize(capacity, 0);
459-
}
460-
Ok(())
461-
}
462-
463-
fn round_up(value: usize, block_size: NonZeroUsize) -> usize {
464-
let block_size = block_size.get();
465-
match value % block_size {
466-
0 => value,
467-
r => value.checked_add(block_size - r).expect("Numerical overflow!"),
468-
}
469-
}

src/utilities/buffer.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* mtcp - TcpListener/TcpStream *with* timeout/cancellation support
3+
* This is free and unencumbered software released into the public domain.
4+
*/
5+
use std::io::{Result as IoResult, Error as IoError, ErrorKind};
6+
use std::num::NonZeroUsize;
7+
use std::slice::from_raw_parts_mut;
8+
9+
pub struct BufferManager<'a> {
10+
buffer: &'a mut Vec<u8>,
11+
limit: usize,
12+
}
13+
14+
impl<'a> BufferManager<'a> {
15+
pub fn from(buffer: &'a mut Vec<u8>, limit: Option<NonZeroUsize>) -> Self {
16+
Self {
17+
buffer,
18+
limit: limit.map_or(usize::MAX, NonZeroUsize::get),
19+
}
20+
}
21+
22+
pub fn valid_data(&self) -> &[u8] {
23+
&self.buffer[..]
24+
}
25+
26+
pub fn alloc_spare_buffer(&mut self, min_length: NonZeroUsize) -> &'a mut[u8] {
27+
self.buffer.reserve(min_length.get());
28+
let spare = self.buffer.spare_capacity_mut();
29+
unsafe {
30+
from_raw_parts_mut(spare.as_mut_ptr() as *mut u8, spare.len())
31+
}
32+
}
33+
34+
pub fn commit(&mut self, additional: usize) -> IoResult<()> {
35+
if additional > 0 {
36+
let new_length = self.buffer.len().checked_add(additional).expect("Numerical overflow!");
37+
if new_length <= self.limit {
38+
assert!(new_length <= self.buffer.capacity());
39+
unsafe {
40+
self.buffer.set_len(new_length)
41+
}
42+
} else {
43+
return Err(IoError::new(ErrorKind::OutOfMemory, "New length exceeds the limit!"))
44+
}
45+
}
46+
Ok(())
47+
}
48+
}

src/utilities/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
* This is free and unencumbered software released into the public domain.
44
*/
55
mod flag;
6+
mod buffer;
67
mod timeout;
78

9+
pub(crate) use buffer::BufferManager;
810
pub(crate) use flag::Flag;
911
pub(crate) use timeout::Timeout;

0 commit comments

Comments
 (0)