diff --git a/Cargo.lock b/Cargo.lock index edd8d3237821..0d08e9ce77b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5122,6 +5122,7 @@ dependencies = [ "test-programs-artifacts", "tokio", "tokio-rustls", + "tracing", "wasmtime", "wasmtime-wasi", "webpki-roots", diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 23fa1e4bba9f..4ec535223683 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -65,6 +65,10 @@ make_vendor "wasi-tls" " tls@v0.2.0-draft+505fc98 " +make_vendor "wasi-tls/src/p3" " + tls@v0.3.0-draft@wit-0.3.0-draft +" + make_vendor "wasi-config" "config@v0.2.0-rc.1" make_vendor "wasi-keyvalue" "keyvalue@219ea36" diff --git a/crates/test-programs/artifacts/build.rs b/crates/test-programs/artifacts/build.rs index 8279c447d2cb..20732a68e83a 100644 --- a/crates/test-programs/artifacts/build.rs +++ b/crates/test-programs/artifacts/build.rs @@ -74,19 +74,20 @@ impl Artifacts { // generates a `foreach_*` macro below. let kind = match test.name.as_str() { s if s.starts_with("p1_") => "p1", - s if s.starts_with("p2_http_") => "p2_http", - s if s.starts_with("p2_cli_") => "p2_cli", s if s.starts_with("p2_api_") => "p2_api", + s if s.starts_with("p2_cli_") => "p2_cli", + s if s.starts_with("p2_http_") => "p2_http", + s if s.starts_with("p2_tls_") => "p2_tls", s if s.starts_with("p2_") => "p2", s if s.starts_with("nn_") => "nn", s if s.starts_with("piped_") => "piped", s if s.starts_with("dwarf_") => "dwarf", s if s.starts_with("config_") => "config", s if s.starts_with("keyvalue_") => "keyvalue", - s if s.starts_with("tls_") => "tls", s if s.starts_with("async_") => "async", - s if s.starts_with("p3_http_") => "p3_http", s if s.starts_with("p3_api_") => "p3_api", + s if s.starts_with("p3_http_") => "p3_http", + s if s.starts_with("p3_tls_") => "p3_tls", s if s.starts_with("p3_") => "p3", s if s.starts_with("fuzz_") => "fuzz", // If you're reading this because you hit this panic, either add diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/p2_tls_sample_application.rs similarity index 100% rename from crates/test-programs/src/bin/tls_sample_application.rs rename to crates/test-programs/src/bin/p2_tls_sample_application.rs diff --git a/crates/test-programs/src/bin/p3_tls_sample_application.rs b/crates/test-programs/src/bin/p3_tls_sample_application.rs new file mode 100644 index 000000000000..4ecfbe4e71a4 --- /dev/null +++ b/crates/test-programs/src/bin/p3_tls_sample_application.rs @@ -0,0 +1,168 @@ +use anyhow::{Context as _, Result, anyhow, bail}; +use core::future::{Future as _, poll_fn}; +use core::pin::pin; +use core::str; +use core::task::{Poll, ready}; +use futures::try_join; +use test_programs::p3::wasi::sockets::ip_name_lookup::resolve_addresses; +use test_programs::p3::wasi::sockets::types::{IpAddress, IpSocketAddress, TcpSocket}; +use test_programs::p3::wasi::tls; +use test_programs::p3::wasi::tls::client::Hello; +use test_programs::p3::wit_stream; +use wit_bindgen::StreamResult; + +struct Component; + +test_programs::p3::export!(Component); + +const PORT: u16 = 443; + +async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { + let request = format!( + "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n" + ); + + let sock = TcpSocket::create(ip.family()).unwrap(); + sock.connect(IpSocketAddress::new(ip, PORT)) + .await + .context("tcp connect failed")?; + + let (sock_rx, sock_rx_fut) = sock.receive(); + let hello = Hello::new(); + hello + .set_server_name(domain) + .map_err(|()| anyhow!("failed to set SNI"))?; + let (sock_tx, conn) = tls::client::connect(hello, sock_rx); + let sock_tx_fut = sock.send(sock_tx); + + let mut conn = pin!(conn.into_future()); + let mut sock_rx_fut = pin!(sock_rx_fut.into_future()); + let mut sock_tx_fut = pin!(sock_tx_fut); + let conn = poll_fn(|cx| match conn.as_mut().poll(cx) { + Poll::Ready(Ok(conn)) => Poll::Ready(Ok(conn)), + Poll::Ready(Err(())) => Poll::Ready(Err(anyhow!("tls handshake failed"))), + Poll::Pending => match sock_tx_fut.as_mut().poll(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Err(anyhow!("Tx stream closed unexpectedly"))), + Poll::Ready(Err(err)) => { + Poll::Ready(Err(anyhow!("Tx stream closed with error: {err:?}"))) + } + Poll::Pending => match ready!(sock_rx_fut.as_mut().poll(cx)) { + Ok(_) => Poll::Ready(Err(anyhow!("Rx stream closed unexpectedly"))), + Err(err) => Poll::Ready(Err(anyhow!("Rx stream closed with error: {err:?}"))), + }, + }, + }) + .await?; + + let (mut req_tx, req_rx) = wit_stream::new(); + let (mut res_rx, result_fut) = tls::client::Handshake::finish(conn, req_rx); + + let res = Vec::with_capacity(8192); + try_join!( + async { + let buf = req_tx.write_all(request.into()).await; + assert_eq!(buf, []); + drop(req_tx); + Ok(()) + }, + async { + let (result, buf) = res_rx.read(res).await; + match result { + StreamResult::Complete(..) => { + drop(res_rx); + let res = String::from_utf8(buf)?; + if res.contains("HTTP/1.1 200 OK") { + Ok(()) + } else { + bail!("server did not respond with 200 OK: {res}") + } + } + StreamResult::Dropped => bail!("read dropped"), + StreamResult::Cancelled => bail!("read cancelled"), + } + }, + async { result_fut.await.map_err(|()| anyhow!("TLS session failed")) }, + async { sock_rx_fut.await.context("TCP receipt failed") }, + async { sock_tx_fut.await.context("TCP transmit failed") }, + )?; + Ok(()) +} + +/// This test sets up a TCP connection using one domain, and then attempts to +/// perform a TLS handshake using another unrelated domain. This should result +/// in a handshake error. +async fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { + const BAD_DOMAIN: &'static str = "wrongdomain.localhost"; + + let sock = TcpSocket::create(ip.family()).unwrap(); + sock.connect(IpSocketAddress::new(ip, PORT)) + .await + .context("tcp connect failed")?; + + let (sock_rx, sock_rx_fut) = sock.receive(); + let hello = Hello::new(); + hello + .set_server_name(BAD_DOMAIN) + .map_err(|()| anyhow!("failed to set SNI"))?; + let (sock_tx, conn) = tls::client::connect(hello, sock_rx); + let sock_tx_fut = sock.send(sock_tx); + + try_join!( + async { + match conn.await { + Err(()) => Ok(()), + Ok(_) => panic!("expecting server name mismatch"), + } + }, + async { sock_rx_fut.await.context("TCP receipt failed") }, + async { sock_tx_fut.await.context("TCP transmit failed") }, + )?; + Ok(()) +} + +async fn try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut) +where + Fut: Future> + 'a, +{ + // since this is testing remote endpoints to ensure system cert store works + // the test uses a couple different endpoints to reduce the number of flakes + const DOMAINS: &'static [&'static str] = &[ + "example.com", + "api.github.com", + "docs.wasmtime.dev", + "bytecodealliance.org", + "www.rust-lang.org", + ]; + + for &domain in DOMAINS { + let result = (|| async { + let ip = resolve_addresses(domain.into()) + .await? + .first() + .map(|a| a.to_owned()) + .ok_or_else(|| anyhow!("DNS lookup failed."))?; + test(&domain, ip).await + })(); + + match result.await { + Ok(()) => return, + Err(e) => { + eprintln!("test for {domain} failed: {e:#}"); + } + } + } + + panic!("all tests failed"); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + println!("sample app"); + try_live_endpoints(test_tls_sample_application).await; + println!("invalid cert"); + try_live_endpoints(test_tls_invalid_certificate).await; + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/p3/mod.rs b/crates/test-programs/src/p3/mod.rs index 87ddddb2d7cf..9e9efa7a7511 100644 --- a/crates/test-programs/src/p3/mod.rs +++ b/crates/test-programs/src/p3/mod.rs @@ -8,11 +8,15 @@ wit_bindgen::generate!({ world testp3 { include wasi:cli/imports@0.3.0-rc-2025-09-16; include wasi:http/imports@0.3.0-rc-2025-09-16; + include wasi:tls/imports@0.3.0-draft; export wasi:cli/run@0.3.0-rc-2025-09-16; } ", - path: "../wasi-http/src/p3/wit", + path: [ + "../wasi-http/src/p3/wit", + "../wasi-tls/src/p3/wit", + ], world: "wasmtime:test/testp3", default_bindings_module: "test_programs::p3", pub_export_macro: true, diff --git a/crates/wasi-tls-nativetls/tests/main.rs b/crates/wasi-tls-nativetls/tests/main.rs index 202a2cd34c4a..894056a78988 100644 --- a/crates/wasi-tls-nativetls/tests/main.rs +++ b/crates/wasi-tls-nativetls/tests/main.rs @@ -62,9 +62,9 @@ macro_rules! assert_test_exists { }; } -test_programs_artifacts::foreach_tls!(assert_test_exists); +test_programs_artifacts::foreach_p2_tls!(assert_test_exists); #[tokio::test(flavor = "multi_thread")] -async fn tls_sample_application() -> Result<()> { - run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await +async fn p2_tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::P2_TLS_SAMPLE_APPLICATION_COMPONENT).await } diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index be715c5b6d1b..dfcc3479bb6d 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -11,6 +11,10 @@ description = "Wasmtime implementation of the wasi-tls API" [lints] workspace = true +[features] +default = [] +p3 = ["wasmtime-wasi/p3", "wasmtime/component-model-async"] + [dependencies] anyhow = { workspace = true } bytes = { workspace = true } @@ -20,6 +24,7 @@ tokio = { workspace = true, features = [ "time", "io-util", ] } +tracing = { workspace = true } wasmtime = { workspace = true, features = ["runtime", "component-model"] } wasmtime-wasi = { workspace = true } diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index f98cdf186b4c..214675dcac07 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -76,6 +76,8 @@ use wasmtime::component::{HasData, ResourceTable}; pub mod bindings; mod host; mod io; +#[cfg(feature = "p3")] +pub mod p3; mod rustls; pub use bindings::types::LinkOptions; diff --git a/crates/wasi-tls/src/p3/bindings.rs b/crates/wasi-tls/src/p3/bindings.rs new file mode 100644 index 000000000000..5dc68fb01be7 --- /dev/null +++ b/crates/wasi-tls/src/p3/bindings.rs @@ -0,0 +1,23 @@ +//! Raw bindings to the `wasi:tls` package. + +#[expect(missing_docs, reason = "generated code")] +mod generated { + wasmtime::component::bindgen!({ + path: "src/p3/wit", + world: "wasi:tls/imports", + imports: { + "wasi:tls/client.[static]handshake.finish": trappable | tracing | store, + "wasi:tls/client.connect": trappable | tracing | store, + "wasi:tls/server.[static]handshake.finish": trappable | tracing | store, + default: trappable | tracing + }, + with: { + "wasi:tls/client.handshake": crate::p3::ClientHandshake, + "wasi:tls/client.hello": crate::p3::ClientHello, + "wasi:tls/server.handshake": crate::p3::ServerHandshake, + "wasi:tls/types.certificate": crate::p3::Certificate, + }, + }); +} + +pub use self::generated::wasi::*; diff --git a/crates/wasi-tls/src/p3/host/client.rs b/crates/wasi-tls/src/p3/host/client.rs new file mode 100644 index 000000000000..440bd6dd5a67 --- /dev/null +++ b/crates/wasi-tls/src/p3/host/client.rs @@ -0,0 +1,259 @@ +use super::{ + CiphertextConsumer, CiphertextProducer, PlaintextConsumer, PlaintextProducer, ResultProducer, + mk_delete, mk_get, mk_get_mut, mk_push, +}; +use crate::p3::bindings::tls::client::{ + Handshake, Hello, Host, HostHandshake, HostHandshakeWithStore, HostHello, HostWithStore, +}; +use crate::p3::bindings::tls::types::Certificate; +use crate::p3::{TlsStream, TlsStreamClientArc, WasiTls, WasiTlsCtxView}; +use anyhow::{Context as _, anyhow, bail}; +use core::mem; +use core::net::{IpAddr, Ipv4Addr}; +use core::pin::{Pin, pin}; +use core::task::{Context, Poll}; +use rustls::client::ResolvesClientCert; +use rustls::pki_types::ServerName; +use std::sync::{Arc, Mutex}; +use tokio::sync::oneshot; +use wasmtime::StoreContextMut; +use wasmtime::component::{Access, FutureProducer, FutureReader, Resource, StreamReader}; + +mk_push!(Hello, push_hello, "client hello"); +mk_get_mut!(Hello, get_hello_mut, "client hello"); +mk_delete!(Hello, delete_hello, "client hello"); + +mk_push!(Handshake, push_handshake, "client handshake"); +mk_get!(Handshake, get_handshake, "client handshake"); +mk_delete!(Handshake, delete_handshake, "client handshake"); + +#[derive(Default)] +enum ConnectProducer { + Pending { + stream: TlsStreamClientArc, + error_rx: oneshot::Receiver, + getter: fn(&mut T) -> WasiTlsCtxView<'_>, + }, + #[default] + Exhausted, +} + +impl FutureProducer for ConnectProducer +where + D: 'static, +{ + type Item = Result, ()>; + + fn poll_produce( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut store: StoreContextMut, + finish: bool, + ) -> Poll>> { + let this = self.get_mut(); + let Self::Pending { + stream, + mut error_rx, + getter, + } = mem::take(this) + else { + return Poll::Ready(Err(anyhow!("polled after ready"))); + }; + if let Poll::Ready(..) = pin!(&mut error_rx).poll(cx) { + return Poll::Ready(Ok(Some(Err(())))); + } + + { + let mut stream_lock = stream.lock(); + let TlsStream { conn, read_tls, .. } = stream_lock.as_deref_mut().unwrap(); + if conn.peer_certificates().is_none() || conn.negotiated_cipher_suite().is_none() { + if !finish { + *read_tls = Some(cx.waker().clone()); + } + drop(stream_lock); + *this = Self::Pending { + stream, + error_rx, + getter, + }; + if finish { + return Poll::Ready(Ok(None)); + } + return Poll::Pending; + } + }; + + let WasiTlsCtxView { table, .. } = getter(store.data_mut()); + + let handshake = Handshake { stream, error_rx }; + let handshake = push_handshake(table, handshake)?; + + Poll::Ready(Ok(Some(Ok(handshake)))) + } +} + +#[derive(Debug)] +struct CertificateResolver; + +impl ResolvesClientCert for CertificateResolver { + fn resolve( + &self, + _root_hint_subjects: &[&[u8]], + _sigschemes: &[rustls::SignatureScheme], + ) -> Option> { + // TODO: implement + None + } + + fn has_certs(&self) -> bool { + false + } +} + +impl Host for WasiTlsCtxView<'_> {} + +impl HostHello for WasiTlsCtxView<'_> { + fn new(&mut self) -> wasmtime::Result> { + push_hello(&mut self.table, Hello::default()) + } + + fn set_server_name( + &mut self, + hello: Resource, + server_name: String, + ) -> wasmtime::Result> { + let hello = get_hello_mut(&mut self.table, &hello)?; + let Ok(server_name) = server_name.try_into() else { + return Ok(Err(())); + }; + hello.server_name = Some(server_name); + Ok(Ok(())) + } + + fn set_alpn_ids( + &mut self, + hello: Resource, + alpn_ids: Vec>, + ) -> wasmtime::Result<()> { + let hello = get_hello_mut(&mut self.table, &hello)?; + hello.alpn_ids = Some(alpn_ids); + Ok(()) + } + + fn set_cipher_suites( + &mut self, + hello: Resource, + cipher_suites: Vec, + ) -> wasmtime::Result<()> { + let hello = get_hello_mut(&mut self.table, &hello)?; + hello.cipher_suites = cipher_suites; + Ok(()) + } + + fn drop(&mut self, hello: Resource) -> wasmtime::Result<()> { + delete_hello(&mut self.table, hello)?; + Ok(()) + } +} + +impl HostWithStore for WasiTls { + fn connect( + mut store: Access, + hello: Resource, + incoming: StreamReader, + ) -> wasmtime::Result<( + StreamReader, + FutureReader, ()>>, + )> { + let Hello { + server_name, + alpn_ids, + cipher_suites, + } = delete_hello(store.get().table, hello)?; + + let roots = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + if !cipher_suites.is_empty() { + // TODO: implement + bail!("custom cipher suites not supported yet") + } + let mut config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_client_cert_resolver(Arc::new(CertificateResolver)); + if let Some(alpn_ids) = alpn_ids { + config.alpn_protocols = alpn_ids; + } + let server_name = if let Some(server_name) = server_name { + server_name + } else { + config.enable_sni = false; + ServerName::IpAddress(IpAddr::V4(Ipv4Addr::UNSPECIFIED).into()) + }; + let conn = rustls::ClientConnection::new(Arc::from(config), server_name) + .context("failed to construct rustls client connection")?; + let (error_tx, error_rx) = oneshot::channel(); + let stream = Arc::new(Mutex::new(TlsStream::new(conn, error_tx))); + + incoming.pipe(&mut store, CiphertextConsumer(Arc::clone(&stream))); + let getter = store.getter(); + Ok(( + StreamReader::new(&mut store, CiphertextProducer(Arc::clone(&stream))), + FutureReader::new( + &mut store, + ConnectProducer::Pending { + stream, + error_rx, + getter, + }, + ), + )) + } +} + +impl HostHandshake for WasiTlsCtxView<'_> { + fn set_client_certificate( + &mut self, + _handshake: Resource, + _cert: Resource, + ) -> wasmtime::Result<()> { + todo!() + } + + fn get_server_certificate( + &mut self, + _handshake: Resource, + ) -> wasmtime::Result>> { + todo!() + } + + fn get_cipher_suite(&mut self, handshake: Resource) -> wasmtime::Result { + let Handshake { stream, .. } = get_handshake(&self.table, &handshake)?; + let mut stream = stream.lock(); + let TlsStream { conn, .. } = stream.as_deref_mut().unwrap(); + let cipher_suite = conn + .negotiated_cipher_suite() + .context("cipher suite not available")?; + Ok(cipher_suite.suite().get_u16()) + } + + fn drop(&mut self, handshake: Resource) -> wasmtime::Result<()> { + delete_handshake(&mut self.table, handshake)?; + Ok(()) + } +} + +impl HostHandshakeWithStore for WasiTls { + fn finish( + mut store: Access, + handshake: Resource, + data: StreamReader, + ) -> wasmtime::Result<(StreamReader, FutureReader>)> { + let Handshake { stream, error_rx } = delete_handshake(&mut store.get().table, handshake)?; + data.pipe(&mut store, PlaintextConsumer(Arc::clone(&stream))); + Ok(( + StreamReader::new(&mut store, PlaintextProducer(stream)), + FutureReader::new(&mut store, ResultProducer(error_rx)), + )) + } +} diff --git a/crates/wasi-tls/src/p3/host/mod.rs b/crates/wasi-tls/src/p3/host/mod.rs new file mode 100644 index 000000000000..2b48c7f0b03e --- /dev/null +++ b/crates/wasi-tls/src/p3/host/mod.rs @@ -0,0 +1,364 @@ +use crate::p3::{TlsStream, TlsStreamArc}; +use anyhow::Context as _; +use core::ops::DerefMut; +use core::pin::Pin; +use core::task::{Context, Poll, Waker}; +use std::io::{Read as _, Write as _}; +use tokio::sync::oneshot; +use wasmtime::StoreContextMut; +use wasmtime::component::{ + Destination, FutureProducer, Source, StreamConsumer, StreamProducer, StreamResult, +}; + +mod client; +mod server; +mod types; + +macro_rules! mk_push { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f( + table: &mut wasmtime::component::ResourceTable, + value: $t, + ) -> wasmtime::Result> { + use anyhow::Context as _; + + table + .push(value) + .context(concat!("failed to push ", $desc, " resource to table")) + } + }; +} + +macro_rules! mk_get { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f<'a>( + table: &'a wasmtime::component::ResourceTable, + resource: &'a wasmtime::component::Resource<$t>, + ) -> wasmtime::Result<&'a $t> { + use anyhow::Context as _; + + table + .get(resource) + .context(concat!("failed to get ", $desc, " resource from table")) + } + }; +} + +macro_rules! mk_get_mut { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f<'a>( + table: &'a mut wasmtime::component::ResourceTable, + resource: &'a wasmtime::component::Resource<$t>, + ) -> wasmtime::Result<&'a mut $t> { + use anyhow::Context as _; + + table.get_mut(resource).context(concat!( + "failed to get ", + $desc, + " resource from table" + )) + } + }; +} + +macro_rules! mk_delete { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f( + table: &mut wasmtime::component::ResourceTable, + resource: wasmtime::component::Resource<$t>, + ) -> wasmtime::Result<$t> { + use anyhow::Context as _; + + table.delete(resource).context(concat!( + "failed to delete ", + $desc, + " resource from table" + )) + } + }; +} + +pub(crate) use {mk_delete, mk_get, mk_get_mut, mk_push}; + +struct CiphertextConsumer(TlsStreamArc); + +impl StreamConsumer for CiphertextConsumer +where + T: DerefMut> + Send + 'static, +{ + type Item = u8; + + fn poll_consume( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut, + src: Source, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + read_tls, + ciphertext_consumer, + ciphertext_producer, + plaintext_consumer, + plaintext_producer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + if !conn.wants_read() { + if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *ciphertext_consumer = Some(cx.waker().clone()); + return Poll::Pending; + } + + let mut src = src.as_direct(store); + if src.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + let n = conn.read_tls(&mut src)?; + debug_assert_ne!(n, 0); + read_tls.take().map(Waker::wake); + + let state = match conn.process_new_packets() { + Ok(state) => state, + Err(err) => { + _ = error_tx.take().unwrap().send(err); + ciphertext_producer.take().map(Waker::wake); + plaintext_consumer.take().map(Waker::wake); + plaintext_producer.take().map(Waker::wake); + return Poll::Ready(Ok(StreamResult::Dropped)); + } + }; + + if state.plaintext_bytes_to_read() > 0 { + plaintext_producer.take().map(Waker::wake); + } + + if state.tls_bytes_to_write() > 0 { + ciphertext_producer.take().map(Waker::wake); + } + + if state.peer_has_closed() { + // even if there are no bytes to read, the producer may be pending + plaintext_producer.take().map(Waker::wake); + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct PlaintextProducer(TlsStreamArc); + +impl StreamProducer for PlaintextProducer +where + T: DerefMut> + Send + 'static, +{ + type Item = u8; + type Buffer = Option; // unused + + fn poll_produce<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut<'a, D>, + dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + ciphertext_consumer, + plaintext_producer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let state = conn.process_new_packets().context("unhandled TLS error")?; + if state.plaintext_bytes_to_read() == 0 { + if state.peer_has_closed() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } else if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *plaintext_producer = Some(cx.waker().clone()); + return Poll::Pending; + } + + let mut dst = dst.as_direct(store, state.plaintext_bytes_to_read()); + let buf = dst.remaining(); + if buf.is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + let n = conn.reader().read(buf)?; + debug_assert_ne!(n, 0); + dst.mark_written(n); + if conn.wants_read() { + ciphertext_consumer.take().map(Waker::wake); + } + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct PlaintextConsumer(TlsStreamArc) +where + T: DerefMut> + Send + 'static; + +impl Drop for PlaintextConsumer +where + T: DerefMut> + Send + 'static, +{ + fn drop(&mut self) { + let mut stream = self.0.lock(); + let TlsStream { + conn, + close_notify, + ciphertext_producer, + .. + } = stream.as_deref_mut().unwrap(); + conn.send_close_notify(); + *close_notify = true; + ciphertext_producer.take().map(Waker::wake); + } +} + +impl StreamConsumer for PlaintextConsumer +where + T: DerefMut> + Send + 'static, + U: 'static, +{ + type Item = u8; + + fn poll_consume( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut, + src: Source, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + ciphertext_producer, + plaintext_consumer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let mut src = src.as_direct(store); + if src.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + + let mut dst = conn.writer(); + let n = dst.write(src.remaining())?; + if n == 0 { + if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *plaintext_consumer = Some(cx.waker().clone()); + return Poll::Pending; + } + src.mark_read(n); + dst.flush()?; + if conn.wants_write() { + ciphertext_producer.take().map(Waker::wake); + } + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct CiphertextProducer(TlsStreamArc); + +impl StreamProducer for CiphertextProducer +where + T: DerefMut> + Send + 'static, +{ + type Item = u8; + type Buffer = Option; // unused + + fn poll_produce<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut<'a, D>, + dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + close_notify, + ciphertext_consumer, + ciphertext_producer, + plaintext_consumer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + if !conn.wants_write() { + if *close_notify { + return Poll::Ready(Ok(StreamResult::Dropped)); + } else if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *ciphertext_producer = Some(cx.waker().clone()); + plaintext_consumer.take().map(Waker::wake); + return Poll::Pending; + } + + let state = conn.process_new_packets().context("unhandled TLS error")?; + let mut dst = dst.as_direct(store, state.tls_bytes_to_write()); + if dst.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + let n = conn.write_tls(&mut dst)?; + debug_assert_ne!(n, 0); + if conn.wants_read() { + ciphertext_consumer.take().map(Waker::wake); + } + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct ResultProducer(oneshot::Receiver); + +impl FutureProducer for ResultProducer { + type Item = Result<(), ()>; + + fn poll_produce( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + _store: StoreContextMut, + finish: bool, + ) -> Poll>> { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(Ok(_err)) => Poll::Ready(Ok(Some(Err(())))), + Poll::Ready(Err(..)) => Poll::Ready(Ok(Some(Ok(())))), + Poll::Pending if finish => Poll::Ready(Ok(None)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/crates/wasi-tls/src/p3/host/server.rs b/crates/wasi-tls/src/p3/host/server.rs new file mode 100644 index 000000000000..94ec6481a477 --- /dev/null +++ b/crates/wasi-tls/src/p3/host/server.rs @@ -0,0 +1,284 @@ +#![expect(unused, reason = "WIP")] + +use super::{PlaintextConsumer, PlaintextProducer, ResultProducer, mk_delete, mk_get, mk_push}; +use crate::p3::bindings::tls::server::{ + Handshake, Host, HostHandshake, HostHandshakeWithStore, HostWithStore, +}; +use crate::p3::bindings::tls::types::Certificate; +use crate::p3::{TlsStream, TlsStreamServerArc, WasiTls, WasiTlsCtxView}; +use anyhow::{Context as _, anyhow}; +use core::mem; +use core::pin::Pin; +use core::task::{Context, Poll}; +use rustls::server::ResolvesServerCert; +use std::sync::{Arc, Mutex}; +use tokio::sync::oneshot; +use wasmtime::StoreContextMut; +use wasmtime::component::{ + Access, Accessor, Destination, FutureReader, Resource, Source, StreamConsumer, StreamProducer, + StreamReader, StreamResult, +}; + +mk_delete!(Handshake, delete_handshake, "server handshake"); +mk_get!(Handshake, get_handshake, "server handshake"); +mk_push!(Handshake, push_handshake, "server handshake"); + +enum CiphertextConsumer { + Pending { + acceptor: rustls::server::Acceptor, + tx: oneshot::Sender< + Result< + ( + rustls::server::Accepted, + oneshot::Sender, + ), + rustls::Error, + >, + >, + }, + Accepted(oneshot::Receiver), + Active(super::CiphertextConsumer), + Corrupted, +} + +impl StreamConsumer for CiphertextConsumer { + type Item = u8; + + fn poll_consume( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut, + src: Source, + finish: bool, + ) -> Poll> { + let this = self.get_mut(); + match mem::replace(this, Self::Corrupted) { + Self::Pending { mut acceptor, tx } => { + let mut src = src.as_direct(store); + if src.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + acceptor.read_tls(&mut src)?; + match acceptor.accept() { + Ok(None) => { + *this = Self::Pending { acceptor, tx }; + Poll::Ready(Ok(StreamResult::Completed)) + } + Ok(Some(accepted)) => { + let (stream_tx, stream_rx) = oneshot::channel(); + _ = tx.send(Ok((accepted, stream_tx))); + *this = Self::Accepted(stream_rx); + Poll::Ready(Ok(StreamResult::Completed)) + } + Err(err) => { + _ = tx.send(Err(err)); + Poll::Ready(Ok(StreamResult::Dropped)) + } + } + } + Self::Accepted(mut rx) => match Pin::new(&mut rx).poll(cx) { + Poll::Ready(Ok(stream)) => { + *this = Self::Active(super::CiphertextConsumer(stream)); + Poll::Ready(Ok(StreamResult::Completed)) + } + Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending if finish => { + *this = Self::Accepted(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + Poll::Pending => { + *this = Self::Accepted(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + }, + Self::Active(ref mut conn) => Pin::new(conn).poll_consume(cx, store, src, finish), + Self::Corrupted => Poll::Ready(Err(anyhow!("corrupted stream consumer state"))), + } + } +} + +enum CiphertextProducer { + Pending(oneshot::Receiver), + Active(super::CiphertextProducer), + Corrupted, +} + +impl StreamProducer for CiphertextProducer { + type Item = u8; + type Buffer = Option; // unused + + fn poll_produce<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut<'a, D>, + dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + let this = self.get_mut(); + match mem::replace(this, Self::Corrupted) { + Self::Pending(mut rx) => match Pin::new(&mut rx).poll(cx) { + Poll::Ready(Ok(stream)) => { + *this = Self::Active(super::CiphertextProducer(stream)); + Poll::Ready(Ok(StreamResult::Completed)) + } + Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending if finish => { + *this = Self::Pending(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + Poll::Pending => { + *this = Self::Pending(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + }, + Self::Active(ref mut conn) => Pin::new(conn).poll_produce(cx, store, dst, finish), + Self::Corrupted => Poll::Ready(Err(anyhow!("corrupted stream producer state"))), + } + } +} + +#[derive(Debug)] +struct CertificateResolver; + +impl ResolvesServerCert for CertificateResolver { + fn resolve( + &self, + hello: rustls::server::ClientHello, + ) -> Option> { + // TODO: Implement + None + } +} + +impl Host for WasiTlsCtxView<'_> {} + +impl HostWithStore for WasiTls { + async fn accept( + store: &Accessor, + incoming: StreamReader, + ) -> wasmtime::Result, Resource), ()>> { + let (accept_tx, accept_rx) = oneshot::channel(); + store.with(|store| { + incoming.pipe( + store, + CiphertextConsumer::Pending { + acceptor: rustls::server::Acceptor::default(), + tx: accept_tx, + }, + ); + }); + let (accepted, consumer_tx) = match accept_rx + .await + .context("oneshot sender dropped unexpectedly")? + { + Ok((accepted, consumer_tx)) => (accepted, consumer_tx), + Err(_err) => return Ok(Err(())), + }; + let (producer_tx, producer_rx) = oneshot::channel(); + store.with(|mut store| { + let handshake = push_handshake( + store.get().table, + Handshake { + accepted, + consumer_tx, + producer_tx, + }, + )?; + Ok(Ok(( + StreamReader::new(&mut store, CiphertextProducer::Pending(producer_rx)), + handshake, + ))) + }) + } +} + +impl HostHandshake for WasiTlsCtxView<'_> { + fn set_server_certificate( + &mut self, + handshake: Resource, + cert: Resource, + ) -> wasmtime::Result<()> { + todo!() + } + + fn get_client_certificate( + &mut self, + handshake: Resource, + ) -> wasmtime::Result, ()>>> { + todo!() + } + + fn get_server_name( + &mut self, + handshake: Resource, + ) -> wasmtime::Result> { + let handshake = get_handshake(&self.table, &handshake)?; + let hello = handshake.accepted.client_hello(); + let server_name = hello.server_name().map(Into::into); + Ok(server_name) + } + + fn get_alpn_ids( + &mut self, + handshake: Resource, + ) -> wasmtime::Result>>> { + let handshake = get_handshake(&self.table, &handshake)?; + let hello = handshake.accepted.client_hello(); + let alpn = hello.alpn().map(|alpn| alpn.map(Into::into).collect()); + Ok(alpn) + } + + fn get_cipher_suites(&mut self, handshake: Resource) -> wasmtime::Result> { + let handshake = get_handshake(&self.table, &handshake)?; + let hello = handshake.accepted.client_hello(); + let cipher_suites = hello + .cipher_suites() + .into_iter() + .map(rustls::CipherSuite::get_u16) + .collect(); + Ok(cipher_suites) + } + + fn set_cipher_suite( + &mut self, + handshake: Resource, + cipher_suite: u16, + ) -> wasmtime::Result<()> { + todo!() + } + + fn drop(&mut self, handshake: Resource) -> wasmtime::Result<()> { + delete_handshake(&mut self.table, handshake)?; + Ok(()) + } +} + +impl HostHandshakeWithStore for WasiTls { + fn finish( + mut store: Access, + handshake: Resource, + data: StreamReader, + ) -> wasmtime::Result<(StreamReader, FutureReader>)> { + let Handshake { + accepted, + consumer_tx, + producer_tx, + } = delete_handshake(&mut store.get().table, handshake)?; + // TODO: configure + let config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(CertificateResolver)); + let conn = accepted + .into_connection(Arc::from(config)) + .context("failed to construct rustls server connection")?; + let (error_tx, error_rx) = oneshot::channel(); + let stream = Arc::new(Mutex::new(TlsStream::new(conn, error_tx))); + data.pipe(&mut store, PlaintextConsumer(Arc::clone(&stream))); + _ = consumer_tx.send(Arc::clone(&stream)); + _ = producer_tx.send(Arc::clone(&stream)); + Ok(( + StreamReader::new(&mut store, PlaintextProducer(stream)), + FutureReader::new(&mut store, ResultProducer(error_rx)), + )) + } +} diff --git a/crates/wasi-tls/src/p3/host/types.rs b/crates/wasi-tls/src/p3/host/types.rs new file mode 100644 index 000000000000..2c71ef0fbb5e --- /dev/null +++ b/crates/wasi-tls/src/p3/host/types.rs @@ -0,0 +1,15 @@ +use super::mk_delete; +use crate::p3::WasiTlsCtxView; +use crate::p3::bindings::tls::types::{Certificate, Host, HostCertificate}; +use wasmtime::component::Resource; + +mk_delete!(Certificate, delete_certificate, "certificate"); + +impl Host for WasiTlsCtxView<'_> {} + +impl HostCertificate for WasiTlsCtxView<'_> { + fn drop(&mut self, cert: Resource) -> wasmtime::Result<()> { + delete_certificate(&mut self.table, cert)?; + Ok(()) + } +} diff --git a/crates/wasi-tls/src/p3/mod.rs b/crates/wasi-tls/src/p3/mod.rs new file mode 100644 index 000000000000..e3be4924a4b4 --- /dev/null +++ b/crates/wasi-tls/src/p3/mod.rs @@ -0,0 +1,180 @@ +//! Experimental, unstable and incomplete implementation of wasip3 version of `wasi:tls`. +//! +//! This module is under heavy development. +//! It is not compliant with semver and is not ready +//! for production use. +//! +//! Bug and security fixes limited to wasip3 will not be given patch releases. +//! +//! Documentation of this module may be incorrect or out-of-sync with the implementation. + +pub mod bindings; +mod host; + +use core::task::Waker; +use std::sync::{Arc, Mutex}; + +use bindings::tls::{client, server, types}; +use rustls::pki_types::ServerName; +use tokio::sync::oneshot; +use wasmtime::component::{HasData, Linker, ResourceTable}; + +/// The type for which this crate implements the `wasi:tls` interfaces. +pub struct WasiTls; + +impl HasData for WasiTls { + type Data<'a> = WasiTlsCtxView<'a>; +} + +/// A trait which provides internal WASI TLS state. +pub trait WasiTlsCtx: Send {} + +/// Default implementation of [WasiTlsCtx]. +#[derive(Clone, Default)] +pub struct DefaultWasiTlsCtx; + +impl WasiTlsCtx for DefaultWasiTlsCtx {} + +/// View into [WasiTlsCtx] implementation and [ResourceTable]. +pub struct WasiTlsCtxView<'a> { + /// Mutable reference to the WASI TLS context. + pub ctx: &'a mut dyn WasiTlsCtx, + + /// Mutable reference to table used to manage resources. + pub table: &'a mut ResourceTable, +} + +/// A trait which provides internal WASI TLS state. +pub trait WasiTlsView: Send { + /// Return a [WasiTlsCtxView] from mutable reference to self. + fn tls(&mut self) -> WasiTlsCtxView<'_>; +} + +/// Add all interfaces from this module into the `linker` provided. +/// +/// This function will add all interfaces implemented by this module to the +/// [`Linker`], which corresponds to the `wasi:tls/imports` world supported by +/// this module. +/// +/// # Example +/// +/// ``` +/// use wasmtime::{Engine, Result, Store, Config}; +/// use wasmtime::component::{Linker, ResourceTable}; +/// use wasmtime_wasi_tls::p3::{DefaultWasiTlsCtx, WasiTlsCtxView, WasiTlsView}; +/// +/// fn main() -> Result<()> { +/// let mut config = Config::new(); +/// config.async_support(true); +/// config.wasm_component_model_async(true); +/// let engine = Engine::new(&config)?; +/// +/// let mut linker = Linker::::new(&engine); +/// wasmtime_wasi_tls::p3::add_to_linker(&mut linker)?; +/// // ... add any further functionality to `linker` if desired ... +/// +/// let mut store = Store::new( +/// &engine, +/// MyState::default(), +/// ); +/// +/// // ... use `linker` to instantiate within `store` ... +/// +/// Ok(()) +/// } +/// +/// #[derive(Default)] +/// struct MyState { +/// tls: DefaultWasiTlsCtx, +/// table: ResourceTable, +/// } +/// +/// impl WasiTlsView for MyState { +/// fn tls(&mut self) -> WasiTlsCtxView<'_> { +/// WasiTlsCtxView { +/// ctx: &mut self.tls, +/// table: &mut self.table, +/// } +/// } +/// } +/// ``` +pub fn add_to_linker(linker: &mut Linker) -> wasmtime::Result<()> +where + T: WasiTlsView + 'static, +{ + client::add_to_linker::<_, WasiTls>(linker, T::tls)?; + server::add_to_linker::<_, WasiTls>(linker, T::tls)?; + types::add_to_linker::<_, WasiTls>(linker, T::tls)?; + Ok(()) +} + +/// Client hello +#[derive(Clone, Default, Eq, PartialEq, Hash)] +pub struct ClientHello { + /// Server name indicator. + pub server_name: Option>, + /// ALPN IDs + pub alpn_ids: Option>>, + /// Cipher suites + pub cipher_suites: Vec, +} + +/// Server hello +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct ServerHello { + /// Cipher suite + pub cipher_suite: u16, +} + +impl ServerHello { + /// Constructs a new server hello message + pub fn new(cipher_suite: u16) -> Self { + Self { cipher_suite } + } +} + +type TlsStreamArc = Arc>>; +type TlsStreamClientArc = TlsStreamArc; +type TlsStreamServerArc = TlsStreamArc; + +/// Client handshake +pub struct ClientHandshake { + stream: TlsStreamClientArc, + error_rx: oneshot::Receiver, +} + +/// Server handshake +pub struct ServerHandshake { + accepted: rustls::server::Accepted, + consumer_tx: oneshot::Sender, + producer_tx: oneshot::Sender, +} + +/// Certificate +pub struct Certificate; + +struct TlsStream { + conn: T, + error_tx: Option>, + close_notify: bool, + read_tls: Option, + ciphertext_consumer: Option, + ciphertext_producer: Option, + plaintext_consumer: Option, + plaintext_producer: Option, +} + +impl TlsStream { + fn new(conn: T, error_tx: oneshot::Sender) -> Self { + Self { + conn, + error_tx: Some(error_tx), + close_notify: false, + read_tls: None, + plaintext_producer: None, + plaintext_consumer: None, + ciphertext_producer: None, + ciphertext_consumer: None, + } + } +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/client.wit b/crates/wasi-tls/src/p3/wit/deps/tls/client.wit new file mode 100644 index 000000000000..2aaaf695fa54 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/client.wit @@ -0,0 +1,38 @@ +interface client { + use types.{certificate}; + + resource hello { + /// Constructs a new ClientHello message. + constructor(); + + /// Sets the server name indicator. + set-server-name: func(server-name: string) -> result; + + /// Sets the ALPN IDs advertised by the client. + set-alpn-ids: func(alpn-ids: list>); + + /// Sets a list of the symmetric cipher options supported by + /// the client, specifically the record protection algorithm + /// (including secret key length) and a hash to be used with HKDF, in + /// descending order of client preference. + /// + /// If this list is empty, the implementation must use a reasonable default. + set-cipher-suites: func(cipher-suites: list); + } + + resource handshake { + set-client-certificate: func(cert: certificate); + + get-server-certificate: func() -> option; + + /// Gets the single cipher suite selected by the server from + /// the list in ClientHello.cipher_suites. + get-cipher-suite: func() -> u16; + + /// Closing the `data` stream will trigger `close_notify`. + finish: static func(this: handshake, data: stream) -> tuple, future>; + } + + /// Initiate the client TLS handshake + connect: func(hello: hello, incoming: stream) -> tuple, future>>; +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/server.wit b/crates/wasi-tls/src/p3/wit/deps/tls/server.wit new file mode 100644 index 000000000000..6d0f23d33b13 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/server.wit @@ -0,0 +1,36 @@ +interface server { + use types.{certificate}; + + resource handshake { + set-server-certificate: func(cert: certificate); + + get-client-certificate: func() -> future>; + + /// Gets the server name indicator. + /// Returns `none` if the client did not supply a SNI. + get-server-name: func() -> option; + + /// Gets the ALPN IDs advertised by the client. + /// Returns `none` if the client did not include an ALPN extension. + get-alpn-ids: func() -> option>>; + + /// Gets a list of the symmetric cipher options supported by + /// the client, specifically the record protection algorithm + /// (including secret key length) and a hash to be used with HKDF, in + /// descending order of client preference. + get-cipher-suites: func() -> list; + + /// Selects the cipher-suite from + /// the list returned by `get-cipher-suites` + /// + /// If this is not called before `finish`, implementation + /// will select appropriate cipher suite. + set-cipher-suite: func(cipher-suite: u16); + + /// Closing the `data` stream will trigger `close_notify`. + finish: static func(this: handshake, data: stream) -> tuple, future>; + } + + /// Accept the client TLS handshake + accept: async func(incoming: stream) -> result, handshake>>; +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/types.wit b/crates/wasi-tls/src/p3/wit/deps/tls/types.wit new file mode 100644 index 000000000000..a0bcecae7da7 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/types.wit @@ -0,0 +1,5 @@ +interface types { + resource certificate { + // TODO: define + } +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/world.wit b/crates/wasi-tls/src/p3/wit/deps/tls/world.wit new file mode 100644 index 000000000000..f605ce358457 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/world.wit @@ -0,0 +1,7 @@ +package wasi:tls@0.3.0-draft; + +world imports { + import client; + import server; + import types; +} diff --git a/crates/wasi-tls/src/p3/wit/world.wit b/crates/wasi-tls/src/p3/wit/world.wit new file mode 100644 index 000000000000..51e0e7e8cc58 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/world.wit @@ -0,0 +1,2 @@ +// We actually don't use this; it's just to let bindgen! find the corresponding world in wit/deps. +package wasmtime:wasi-tls; diff --git a/crates/wasi-tls/tests/main.rs b/crates/wasi-tls/tests/p2.rs similarity index 90% rename from crates/wasi-tls/tests/main.rs rename to crates/wasi-tls/tests/p2.rs index d80797ff3871..f248c9bc466b 100644 --- a/crates/wasi-tls/tests/main.rs +++ b/crates/wasi-tls/tests/p2.rs @@ -62,9 +62,9 @@ macro_rules! assert_test_exists { }; } -test_programs_artifacts::foreach_tls!(assert_test_exists); +test_programs_artifacts::foreach_p2_tls!(assert_test_exists); #[tokio::test(flavor = "multi_thread")] -async fn tls_sample_application() -> Result<()> { - run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await +async fn p2_tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::P2_TLS_SAMPLE_APPLICATION_COMPONENT).await } diff --git a/crates/wasi-tls/tests/p3.rs b/crates/wasi-tls/tests/p3.rs new file mode 100644 index 000000000000..c42ce1acca24 --- /dev/null +++ b/crates/wasi-tls/tests/p3.rs @@ -0,0 +1,81 @@ +use anyhow::{Context as _, Result, anyhow}; +use wasmtime::Store; +use wasmtime::component::{Component, Linker, ResourceTable}; +use wasmtime_wasi::p3::bindings::Command; +use wasmtime_wasi::{WasiCtx, WasiCtxView, WasiView}; +use wasmtime_wasi_tls::p3::{DefaultWasiTlsCtx, WasiTlsCtxView, WasiTlsView}; + +struct Ctx { + table: ResourceTable, + wasi_ctx: WasiCtx, + wasi_tls_ctx: DefaultWasiTlsCtx, +} + +impl WasiView for Ctx { + fn ctx(&mut self) -> WasiCtxView<'_> { + WasiCtxView { + ctx: &mut self.wasi_ctx, + table: &mut self.table, + } + } +} + +impl WasiTlsView for Ctx { + fn tls(&mut self) -> WasiTlsCtxView<'_> { + WasiTlsCtxView { + ctx: &mut self.wasi_tls_ctx, + table: &mut self.table, + } + } +} + +async fn run_test(path: &str) -> Result<()> { + let ctx = Ctx { + table: ResourceTable::new(), + wasi_ctx: WasiCtx::builder() + .inherit_stdout() + .inherit_stderr() + .inherit_network() + .allow_ip_name_lookup(true) + .build(), + wasi_tls_ctx: DefaultWasiTlsCtx, + }; + + let engine = test_programs_artifacts::engine(|config| { + config.async_support(true); + config.wasm_component_model_async(true); + }); + let mut store = Store::new(&engine, ctx); + + let mut linker = Linker::new(&engine); + // TODO: Remove once test components are not built for `wasm32-wasip1` + wasmtime_wasi::p2::add_to_linker_async(&mut linker) + .context("failed to link `wasi:cli@0.2.x`")?; + wasmtime_wasi::p3::add_to_linker(&mut linker).context("failed to link `wasi:cli@0.3.x`")?; + wasmtime_wasi_tls::p3::add_to_linker(&mut linker)?; + + let component = Component::from_file(&engine, path)?; + let command = Command::instantiate_async(&mut store, &component, &linker) + .await + .context("failed to instantiate `wasi:cli/command`")?; + store + .run_concurrent(async move |store| command.wasi_cli_run().call_run(store).await) + .await + .context("failed to call `wasi:cli/run#run`")? + .context("guest trapped")? + .map_err(|()| anyhow!("`wasi:cli/run#run` failed")) +} + +macro_rules! assert_test_exists { + ($name:ident) => { + #[expect(unused_imports, reason = "just here to assert it exists")] + use self::$name as _; + }; +} + +test_programs_artifacts::foreach_p3_tls!(assert_test_exists); + +#[tokio::test(flavor = "multi_thread")] +async fn p3_tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::P3_TLS_SAMPLE_APPLICATION_COMPONENT).await +} diff --git a/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs index 9210914a182c..daee20335ed8 100644 --- a/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs +++ b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs @@ -336,6 +336,20 @@ pub struct DirectDestination<'a, D: 'static> { store: StoreContextMut<'a, D>, } +impl std::io::Write for DirectDestination<'_, D> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let rem = self.remaining(); + let n = rem.len().min(buf.len()); + rem[..n].copy_from_slice(&buf[..n]); + self.mark_written(n); + Ok(n) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + impl DirectDestination<'_, D> { /// Provide direct access to the writer's buffer. pub fn remaining(&mut self) -> &mut [u8] { @@ -836,6 +850,16 @@ pub struct DirectSource<'a, D: 'static> { store: StoreContextMut<'a, D>, } +impl std::io::Read for DirectSource<'_, D> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let rem = self.remaining(); + let n = rem.len().min(buf.len()); + buf[..n].copy_from_slice(&rem[..n]); + self.mark_read(n); + Ok(n) + } +} + impl DirectSource<'_, D> { /// Provide direct access to the writer's buffer. pub fn remaining(&mut self) -> &[u8] {