Skip to content

Commit

Permalink
bindings: fix handling of s2n_shutdown errors
Browse files Browse the repository at this point in the history
  • Loading branch information
lrstewart committed Jan 11, 2024
1 parent edebdae commit 1133ece
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 72 deletions.
103 changes: 34 additions & 69 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0

use errno::{set_errno, Errno};
use pin_project_lite::pin_project;
use s2n_tls::{
config::Config,
connection::{Builder, Connection},
Expand All @@ -12,7 +11,7 @@ use s2n_tls::{
use std::{
fmt,
future::Future,
io, mem,
io,
os::raw::{c_int, c_void},
pin::Pin,
task::{
Expand Down Expand Up @@ -139,26 +138,15 @@ where
}
}

pin_project! {
struct BlindingState {
#[pin]
timer: Sleep,

// The remembered error if we got into blinding because of
// an error, or Ok(()) if we didn't. After returning the error,
// this goes back to Ok(()).
remembered_error: Result<(), Error>,
}
}

pub struct TlsStream<S, C = Connection>
where
C: AsRef<Connection> + AsMut<Connection> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
conn: C,
stream: S,
blinding: Option<Pin<Box<BlindingState>>>,
blinding: Option<Pin<Box<Sleep>>>,
shutdown_error: Option<Error>,
}

impl<S, C> TlsStream<S, C>
Expand All @@ -182,6 +170,7 @@ where
conn,
stream,
blinding: None,
shutdown_error: None,
};
TlsHandshake {
tls: &mut tls,
Expand Down Expand Up @@ -255,35 +244,6 @@ where
})
}

// Sets the blinding timer to the remaining blinding delay and possibly
// remembers an error.
//
// Returns the error if there was no blinding needed and the error
// did not need to be remembered.
fn set_blinding_timer(
self: Pin<&mut Self>,
mut remembered_error: Result<(), Error>,
) -> Result<(), Error> {
let tls = self.get_mut();

if tls.blinding.is_none() {
let delay = tls.as_ref().remaining_blinding_delay()?;
if !delay.is_zero() {
// Sleep operates at the milisecond resolution, so add an extra
// millisecond to account for any stray nanoseconds.
let safety = Duration::from_millis(1);
// Return the error *later*, after the blinding is done
let remembered_error = mem::replace(&mut remembered_error, Ok(()));
tls.blinding = Some(Box::pin(BlindingState {
timer: sleep(delay.saturating_add(safety)),
remembered_error,
}));
}
}

remembered_error
}

/// Polls the blinding timer, if there is any.
///
/// s2n has a "blinding" functionality - when a bad behavior from the peer
Expand All @@ -296,25 +256,24 @@ where
/// before dropping an s2n connection, you should wait until either
/// `poll_blinding` or `poll_shutdown` (which calls `poll_blinding`
/// internally) returns ready.
pub fn poll_blinding(
mut self: Pin<&mut Self>,
ctx: &mut Context<'_>,
) -> Poll<Result<(), Error>> {
self.as_mut().set_blinding_timer(Ok(()))?;

pub fn poll_blinding(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let tls = self.get_mut();

if let Some(blinding) = &mut tls.blinding {
ready!(blinding.as_mut().project().timer.as_mut().poll(ctx));

// Set blinding to None to ensure the next go can have blinding
let mut blinding = tls.blinding.take().unwrap();
if tls.blinding.is_none() {
let delay = tls.as_ref().remaining_blinding_delay()?;
if !delay.is_zero() {
// Sleep operates at the milisecond resolution, so add an extra
// millisecond to account for any stray nanoseconds.
let safety = Duration::from_millis(1);
tls.blinding = Some(Box::pin(sleep(delay.saturating_add(safety))));
}
};

// If there is an error, return it
mem::replace(blinding.as_mut().project().remembered_error, Ok(()))?;
if let Some(timer) = tls.blinding.as_mut() {
ready!(timer.as_mut().poll(ctx));
tls.blinding = None;
}

// Otherwise we are OK
Poll::Ready(Ok(()))
}

Expand Down Expand Up @@ -404,19 +363,25 @@ where
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().poll_blinding(ctx))?;

let status = ready!(self.as_mut().with_io(ctx, |mut context| {
context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ()))
}));
// s2n_shutdown should not be called again if it errors
if self.shutdown_error.is_none() {
let result = ready!(self.as_mut().with_io(ctx, |mut context| {
context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ()))
}));
if let Err(error) = result {
self.shutdown_error = Some(error);
return self.poll_shutdown(ctx);
}
};

if let Err(e) = status {
// In case of an error shutting down, make sure you wait for
// the blinding timeout.
self.as_mut().set_blinding_timer(Err(e))?;
ready!(self.as_mut().poll_blinding(ctx))?;
unreachable!("should have returned the error we just put in!");
}
let tcp_result = ready!(Pin::new(&mut self.as_mut().stream).poll_shutdown(ctx));

Pin::new(&mut self.as_mut().stream).poll_shutdown(ctx)
let result = if let Some(error) = self.shutdown_error.take() {
Err(error).map_err(io::Error::from)
} else {
tcp_result
};
Ready(result)
}
}

Expand Down
2 changes: 2 additions & 0 deletions bindings/rust/s2n-tls-tokio/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub static RSA_KEY_PEM: &[u8] = include_bytes!(concat!(
pub const MIN_BLINDING_SECS: Duration = Duration::from_secs(10);
pub const MAX_BLINDING_SECS: Duration = Duration::from_secs(30);

pub static TEST_STR: &str = "hello world";

pub async fn get_streams() -> Result<(TcpStream, TcpStream), tokio::io::Error> {
let localhost = "127.0.0.1".to_owned();
let listener = TcpListener::bind(format!("{}:0", localhost)).await?;
Expand Down
32 changes: 30 additions & 2 deletions bindings/rust/s2n-tls-tokio/tests/common/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ use tokio::{

type ReadFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &mut ReadBuf) -> Poll<io::Result<()>>>;
type WriteFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &[u8]) -> Poll<io::Result<usize>>>;
type ShutdownFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context) -> Poll<io::Result<()>>>;

#[derive(Default)]
struct OverrideMethods {
next_read: Option<ReadFn>,
next_write: Option<WriteFn>,
next_shutdown: Option<ShutdownFn>,
}

#[derive(Default)]
Expand All @@ -36,6 +38,22 @@ impl Overrides {
overrides.next_write = input;
}
}

pub fn next_shutdown(&self, input: Option<ShutdownFn>) {
if let Ok(mut overrides) = self.0.lock() {
overrides.next_shutdown = input;
}
}

pub fn is_consumed(&self) -> bool {
if let Ok(overrides) = self.0.lock() {
overrides.next_read.is_none()
&& overrides.next_write.is_none()
&& overrides.next_shutdown.is_none()
} else {
false
}
}
}

unsafe impl Send for Overrides {}
Expand Down Expand Up @@ -100,7 +118,17 @@ impl AsyncWrite for TestStream {
Pin::new(&mut self.stream).poll_flush(ctx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_shutdown(ctx)
fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
let s = self.get_mut();
let stream = Pin::new(&mut s.stream);
let action = match s.overrides.0.lock() {
Ok(mut overrides) => overrides.next_shutdown.take(),
_ => None,
};
if let Some(f) = action {
(f)(stream, ctx)
} else {
stream.poll_shutdown(ctx)
}
}
}
113 changes: 112 additions & 1 deletion bindings/rust/s2n-tls-tokio/tests/shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

use s2n_tls::error;
use s2n_tls_tokio::{TlsAcceptor, TlsConnector, TlsStream};
use std::{convert::TryFrom, sync::Arc};
use std::{
convert::TryFrom,
io,
sync::Arc,
task::Poll::{Pending, Ready},
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
join, time,
Expand Down Expand Up @@ -281,3 +286,109 @@ async fn shutdown_with_poll_blinding() -> Result<(), Box<dyn std::error::Error>>

Ok(())
}

#[tokio::test(start_paused = true)]
async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(common::server_config()?.build()?);

let (server_stream, client_stream) = common::get_streams().await?;
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();

let (mut client, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Attempt to shutdown the client. This will eventually fail because the
// server has not written the close_notify message yet, but it will at least
// write the close_notify message that the server needs.
// Because time is mocked for testing, this does not actually take 10 minutes.
// TODO: replace this with a half-close once the bindings support half-close.
_ = time::timeout(time::Duration::from_secs(600), client.shutdown()).await;

// The underlying stream should return a unique error on shutdown
overrides.next_shutdown(Some(Box::new(|_, _| {
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
})));

// Shutdown should complete with the correct error from the underlying stream
let result = server.shutdown().await;
let error = result.unwrap_err().into_inner().unwrap();
assert!(error.to_string() == common::TEST_STR);

Ok(())
}

#[tokio::test]
async fn shutdown_with_tls_error_and_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(common::server_config()?.build()?);

let (server_stream, client_stream) = common::get_streams().await?;
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();

let (_, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Both s2n_shutdown and the underlying stream should error on shutdown
overrides.next_read(Some(Box::new(|_, _, _| {
Ready(Err(io::Error::from(io::ErrorKind::Other)))
})));
overrides.next_shutdown(Some(Box::new(|_, _| {
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
})));

// Shutdown should complete with the correct error from s2n_shutdown
let result = server.shutdown().await;
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
// Any non-blocking read error is translated as "IOError"
assert!(error.kind() == error::ErrorType::IOError);

// Even if s2n_shutdown fails, we need to close the underlying stream.
// Make sure we called our mock shutdown, consuming it.
assert!(overrides.is_consumed());

Ok(())
}

#[tokio::test]
async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::error::Error>> {
let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(common::server_config()?.build()?);

let (server_stream, client_stream) = common::get_streams().await?;
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();

let (_, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// We want s2n_shutdown to fail on read in order to ensure that it is only
// called once on failure.
// If s2n_shutdown were called again, the second call would hang waiting
// for nonexistent input from the peer.
overrides.next_read(Some(Box::new(|_, _, _| {
Ready(Err(io::Error::from(io::ErrorKind::Other)))
})));

// The underlying stream should initially return Pending, delaying shutdown
overrides.next_shutdown(Some(Box::new(|_, ctx| {
ctx.waker().wake_by_ref();
Pending
})));

// Shutdown should complete with the correct error from s2n_shutdown
let result = server.shutdown().await;
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
// Any non-blocking read error is translated as "IOError"
assert!(error.kind() == error::ErrorType::IOError);

// Even if s2n_shutdown fails, we need to close the underlying stream.
// Make sure we at least called our mock shutdown, consuming it.
assert!(overrides.is_consumed());

Ok(())
}

0 comments on commit 1133ece

Please sign in to comment.