diff --git a/tlsn/examples/interactive/interactive.rs b/tlsn/examples/interactive/interactive.rs index 3f9c4d1c43..9ecaffa3a3 100644 --- a/tlsn/examples/interactive/interactive.rs +++ b/tlsn/examples/interactive/interactive.rs @@ -1,4 +1,3 @@ -use futures::AsyncWriteExt; use http_body_util::Empty; use hyper::{body::Bytes, Request, StatusCode, Uri}; use hyper_util::rt::TokioIo; @@ -49,6 +48,8 @@ async fn prover( let server_port = uri.port_u16().unwrap_or(443); // Create prover and connect to verifier. + // + // Perform the setup phase with the verifier. let prover = Prover::new( ProverConfig::builder() .id(id) @@ -64,9 +65,18 @@ async fn prover( let tls_client_socket = tokio::net::TcpStream::connect((server_domain, server_port)) .await .unwrap(); + + // Pass server connection into the prover. let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await.unwrap(); + + // Grab a controller for the Prover so we can enable deferred decryption. + let ctrl = prover_fut.control(); + + // Wrap the connection in a TokioIo compatibility layer to use it with hyper. let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); + + // Spawn the Prover to run in the background. let prover_task = tokio::spawn(prover_fut); // MPC-TLS Handshake. @@ -75,7 +85,12 @@ async fn prover( .await .unwrap(); - let connection_task = tokio::spawn(connection.without_shutdown()); + // Spawn the connection to run in the background. + tokio::spawn(connection); + + // Enable deferred decryption. This speeds up the proving time, but doesn't + // let us see the decrypted data until after the connection is closed. + ctrl.defer_decryption().await.unwrap(); // MPC-TLS: Send Request and wait for Response. let request = Request::builder() @@ -90,10 +105,6 @@ async fn prover( assert!(response.status() == StatusCode::OK); - // Close TLS Connection. - let tls_connection = connection_task.await.unwrap().unwrap().io.into_inner(); - tls_connection.compat().close().await.unwrap(); - // Create proof for the Verifier. let mut prover = prover_task.await.unwrap().unwrap().start_prove(); redact_and_reveal_received_data(&mut prover); @@ -128,6 +139,7 @@ async fn verifier( response .find("BEGIN PUBLIC KEY") .expect("Expected valid public key in JSON response"); + // Check Session info: server name. assert_eq!(session_info.server_name.as_str(), SERVER_DOMAIN);