Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s2n-tls-tokio: use s2n_shutdown_send instead of s2n_shutdown #4374

Merged
merged 4 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,19 @@ where
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().poll_blinding(ctx))?;

// s2n_shutdown must not be called again if it errors
// s2n_shutdown_send must not be called again if it errors
if self.shutdown_error.is_none() {
let result = ready!(self.as_mut().with_io(ctx, |mut context| {
context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ()))
context
.conn
.as_mut()
.poll_shutdown_send()
.map(|r| r.map(|_| ()))
}));
if let Err(error) = result {
self.shutdown_error = Some(error);
// s2n_shutdown reading might have triggered blinding again
ready!(self.as_mut().poll_blinding(ctx))?;
// s2n_shutdown_send only writes, so will never trigger blinding again.
// So we do not need to poll_blinding again after this error.
Comment on lines -373 to +378
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went back and forth on whether to leave the call to poll_blinding in or not. We could leave it in just in case we ever introduce a possible blinded error in s2n_shutdown_send, but that should never happen since we only blind when reading. In the end I removed it because I can't test it anymore without doing something crazy and undefined like calling s2n_recv from inside my mocked tcp write call. That's also why I deleted the "shutdown_with_blinding_bad_close_record" test-- I can't reasonably trigger that case anymore.

I left a comment as a reminder of the "double-blinding" problem just in case. But again, we should never read here, since it's the AsyncWrite trait :)

}
};

Expand Down
111 changes: 23 additions & 88 deletions bindings/rust/s2n-tls-tokio/tests/shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ use tokio::{

pub mod common;

// An arbitrary but very long timeout.
// No valid single IO operation should take anywhere near 10 minutes.
pub const LONG_TIMEOUT: time::Duration = time::Duration::from_secs(600);

async fn read_until_shutdown<S: AsyncRead + AsyncWrite + Unpin>(
stream: &mut TlsStream<S>,
) -> Result<(), std::io::Error> {
Expand Down Expand Up @@ -166,18 +162,6 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
let (mut client, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Attempt to shutdown the client. This will eventually fail because the
// server has not written the close_notify message yet, but it will at least
// write the close_notify message that the server needs.
//
// Because this test begins paused and relies on auto-advancing, this does
// not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs:
// https://docs.rs/tokio/latest/tokio/time/fn.pause.html
//
// TODO: replace this with a half-close once the bindings support half-close.
let timeout = time::timeout(LONG_TIMEOUT, client.shutdown()).await;
assert!(timeout.is_err());

// Setup a bad record for the next read
overrides.next_read(Some(Box::new(|_, _, buf| {
// Parsing the header is one of the blinded operations
Expand All @@ -202,53 +186,9 @@ async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
// Server MUST eventually successfully shutdown
assert!(result.is_ok());

// Shutdown MUST have sent the close_notify message needed by the peer
// to also shutdown successfully.
client.shutdown().await?;

Ok(())
}

#[tokio::test(start_paused = true)]
async fn shutdown_with_blinding_bad_close_record() -> Result<(), Box<dyn std::error::Error>> {
let clock = common::TokioTime::default();
let mut server_config = common::server_config()?;
server_config.set_monotonic_clock(clock)?;

let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(server_config.build()?);

let (server_stream, client_stream) = common::get_streams().await?;
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();
let (mut client, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Setup a bad record for the next read
overrides.next_read(Some(Box::new(|_, _, buf| {
// Parsing the header is one of the blinded operations
// in s2n_shutdown, so provide a malformed header.
let zeroed_header = [23, 0, 0, 0, 0];
buf.put_slice(&zeroed_header);
Ok(()).into()
})));

let time_start = time::Instant::now();
let result = server.shutdown().await;
let time_elapsed = time_start.elapsed();

// Shutdown MUST NOT complete faster than minimal blinding time.
assert!(time_elapsed > common::MIN_BLINDING_SECS);

// Shutdown MUST eventually complete with the correct error after blinding.
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
assert!(error.kind() == error::ErrorType::ProtocolError);
assert!(error.name() == "S2N_ERR_BAD_MESSAGE");

// Shutdown MUST have sent the close_notify message needed by the peer
// to also shutdown successfully.
client.shutdown().await?;
// Shutdown MUST have sent the close_notify message needed for EOF.
let mut received = [0; 1];
assert!(client.read(&mut received).await? == 0);

Ok(())
}
Expand Down Expand Up @@ -295,7 +235,7 @@ async fn shutdown_with_poll_blinding() -> Result<(), Box<dyn std::error::Error>>
Ok(())
}

#[tokio::test(start_paused = true)]
#[tokio::test]
async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(common::server_config()?.build()?);
Expand All @@ -304,20 +244,9 @@ async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();

let (mut client, mut server) =
let (_, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Attempt to shutdown the client. This will eventually fail because the
// server has not written the close_notify message yet, but it will at least
// write the close_notify message that the server needs.
//
// Because this test begins paused and relies on auto-advancing, this does
// not actually require waiting LONG_TIMEOUT. See the tokio `pause()` docs:
// https://docs.rs/tokio/latest/tokio/time/fn.pause.html
//
// TODO: replace this with a half-close once the bindings support half-close.
_ = time::timeout(time::Duration::from_secs(600), client.shutdown()).await;

// The underlying stream should return a unique error on shutdown
overrides.next_shutdown(Some(Box::new(|_, _| {
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
Expand All @@ -343,22 +272,22 @@ async fn shutdown_with_tls_error_and_tcp_error() -> Result<(), Box<dyn std::erro
let (_, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// Both s2n_shutdown and the underlying stream should error on shutdown
overrides.next_read(Some(Box::new(|_, _, _| {
// Both s2n_shutdown_send and the underlying stream should error on shutdown
overrides.next_write(Some(Box::new(|_, _, _| {
Ready(Err(io::Error::from(io::ErrorKind::Other)))
})));
overrides.next_shutdown(Some(Box::new(|_, _| {
Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
})));

// Shutdown should complete with the correct error from s2n_shutdown
// Shutdown should complete with the correct error from s2n_shutdown_send
let result = server.shutdown().await;
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
// Any non-blocking read error is translated as "IOError"
assert!(error.kind() == error::ErrorType::IOError);

// Even if s2n_shutdown fails, we need to close the underlying stream.
// Even if s2n_shutdown_send fails, we need to close the underlying stream.
// Make sure we called our mock shutdown, consuming it.
assert!(overrides.is_consumed());

Expand All @@ -374,14 +303,11 @@ async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::erro
let server_stream = common::TestStream::new(server_stream);
let overrides = server_stream.overrides();

let (_, mut server) =
let (mut client, mut server) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;

// We want s2n_shutdown to fail on read in order to ensure that it is only
// called once on failure.
// If s2n_shutdown were called again, the second call would hang waiting
// for nonexistent input from the peer.
overrides.next_read(Some(Box::new(|_, _, _| {
// We want s2n_shutdown_send to produce an error on write
overrides.next_write(Some(Box::new(|_, _, _| {
Ready(Err(io::Error::from(io::ErrorKind::Other)))
})));

Expand All @@ -391,16 +317,25 @@ async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::erro
Pending
})));

// Shutdown should complete with the correct error from s2n_shutdown
// Shutdown should complete with the correct error from s2n_shutdown_send
let result = server.shutdown().await;
let io_error = result.unwrap_err();
let error: error::Error = io_error.try_into()?;
// Any non-blocking read error is translated as "IOError"
assert!(error.kind() == error::ErrorType::IOError);

// Even if s2n_shutdown fails, we need to close the underlying stream.
// Even if s2n_shutdown_send fails, we need to close the underlying stream.
// Make sure we at least called our mock shutdown, consuming it.
assert!(overrides.is_consumed());

maddeleine marked this conversation as resolved.
Show resolved Hide resolved
// Since s2n_shutdown_send failed, we should NOT have sent a close_notify.
// Make sure the peer doesn't receive a close_notify.
// If this is not true, then we're incorrectly calling s2n_shutdown_send
// again after an error.
let mut received = [0; 1];
let io_error = client.read(&mut received).await.unwrap_err();
let error: error::Error = io_error.try_into()?;
assert!(error.kind() == error::ErrorType::ConnectionClosed);

Ok(())
}
16 changes: 16 additions & 0 deletions bindings/rust/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,22 @@ impl Connection {
}
}

/// Attempts a graceful shutdown of the write side of a TLS connection.
///
/// Unlike Self::poll_shutdown, no reponse from the peer is necessary.
/// If using TLS1.3, the connection can continue to be used for reading afterwards.
pub fn poll_shutdown_send(&mut self) -> Poll<Result<&mut Self, Error>> {
if !self.remaining_blinding_delay()?.is_zero() {
return Poll::Pending;
}
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
unsafe {
s2n_shutdown_send(self.connection.as_ptr(), &mut blocked)
.into_poll()
.map_ok(|_| self)
}
}

/// Returns the TLS alert code, if any
pub fn alert(&self) -> Option<u8> {
let alert =
Expand Down
Loading