From 3057c062a46d44116fd27faeb0a4903fb5241e66 Mon Sep 17 00:00:00 2001 From: Manish Date: Fri, 21 Nov 2025 15:17:52 +0530 Subject: [PATCH 1/2] feat: add graceful shutdown with SIGTERM handling and GOAWAY message support --- moq-clock-ietf/src/main.rs | 6 +- moq-pub/src/main.rs | 2 +- moq-relay-ietf/src/relay.rs | 83 +++++++++++++++++++++++-- moq-relay-ietf/src/remote.rs | 18 ++++-- moq-relay-ietf/src/session.rs | 13 ++-- moq-sub/src/main.rs | 2 +- moq-transport/src/message/mod.rs | 6 ++ moq-transport/src/session/mod.rs | 51 ++++++++++++++- moq-transport/src/session/subscriber.rs | 22 +++++++ 9 files changed, 182 insertions(+), 21 deletions(-) diff --git a/moq-clock-ietf/src/main.rs b/moq-clock-ietf/src/main.rs index 199116cb..afb525bc 100644 --- a/moq-clock-ietf/src/main.rs +++ b/moq-clock-ietf/src/main.rs @@ -64,7 +64,7 @@ async fn main() -> anyhow::Result<()> { let clock_publisher = clock::Publisher::new_datagram(track_writer.datagrams()?); tokio::select! { - res = session.run() => res.context("session error")?, + res = session.run(None) => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, } @@ -80,7 +80,7 @@ async fn main() -> anyhow::Result<()> { let clock_publisher = clock::Publisher::new(track_writer.subgroups()?); tokio::select! { - res = session.run() => res.context("session error")?, + res = session.run(None) => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, } @@ -104,7 +104,7 @@ async fn main() -> anyhow::Result<()> { let clock_subscriber = clock::Subscriber::new(track_reader); tokio::select! { - res = session.run() => res.context("session error")?, + res = session.run(None) => res.context("session error")?, res = clock_subscriber.run() => res.context("clock error")?, res = subscriber.subscribe(track_writer) => res.context("failed to subscribe to track")?, } diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index 7e71764c..74eb66c6 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -76,7 +76,7 @@ async fn main() -> anyhow::Result<()> { .context("failed to create MoQ Transport publisher")?; tokio::select! { - res = session.run() => res.context("session error")?, + res = session.run(None) => res.context("session error")?, res = run_media(media) => { res.context("media error")? }, diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 41a68237..0f2c1181 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -4,6 +4,8 @@ use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_native_ietf::quic; +use moq_transport::session::SessionMigration; +use tokio::sync::broadcast; use url::Url; use crate::{Api, Consumer, Locals, Producer, Remotes, RemotesConsumer, RemotesProducer, Session}; @@ -96,9 +98,58 @@ impl Relay { pub async fn run(self) -> anyhow::Result<()> { let mut tasks = FuturesUnordered::new(); + // Setup SIGTERM handler and broadcast channel + #[cfg(unix)] + let mut signal_term = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; + let mut signal_int = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?; + + let (signal_tx, signal_rx) = broadcast::channel::(16); + + // Get server address early for the shutdown signal + let server_addr = self + .quic + .server + .as_ref() + .context("missing TLS certificate")? + .local_addr()?; + let shutdown_uri = format!("https://{}", server_addr); + + // Spawn task to listen for SIGTERM and broadcast shutdown + let signal_tx_clone = signal_tx.clone(); + tasks.push( + async move { + log::info!("Listening for SIGTERM"); + #[cfg(unix)] + { + tokio::select! { + _ = signal_term.recv() => { + log::info!("Received SIGTERM"); + } + _ = signal_int.recv() => { + log::info!("Received SIGINT"); + } + } + log::info!("broadcasting shutdown to all sessions"); + + if let Err(e) = signal_tx.send(SessionMigration { uri: shutdown_uri }) { + log::error!("failed to broadcast shutdown: {}", e); + } + } + #[cfg(not(unix))] + { + std::future::pending::<()>().await; + } + Ok(()) + } + .boxed(), + ); + // Start the remotes producer task, if any let remotes = self.remotes.map(|(producer, consumer)| { - tasks.push(producer.run().boxed()); + let signal_rx = signal_rx.resubscribe(); + tasks.push(producer.run(signal_rx).boxed()); consumer }); @@ -133,7 +184,15 @@ impl Relay { let forward_producer = session.producer.clone(); - tasks.push(async move { session.run().await.context("forwarding failed") }.boxed()); + tasks.push( + async move { + session + .run(signal_tx_clone.subscribe()) + .await + .context("forwarding failed") + } + .boxed(), + ); forward_producer } else { @@ -143,6 +202,7 @@ impl Relay { // Start the QUIC server loop let mut server = self.quic.server.context("missing TLS certificate")?; log::info!("listening on {}", server.local_addr()?); + let mut cloned_signal_rx = signal_rx.resubscribe(); loop { tokio::select! { @@ -158,10 +218,10 @@ impl Relay { let remotes = remotes.clone(); let forward = forward_producer.clone(); let api = self.api.clone(); + let session_signal_rx = signal_rx.resubscribe(); // Spawn a new task to handle the connection tasks.push(async move { - // Create the MoQ session over the connection (setup handshake etc) let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path).await { Ok(session) => session, @@ -178,7 +238,7 @@ impl Relay { consumer: subscriber.map(|subscriber| Consumer::new(subscriber, locals, api, forward)), }; - if let Err(err) = session.run().await { + if let Err(err) = session.run(session_signal_rx).await { log::warn!("failed to run MoQ session: {}", err); } @@ -186,6 +246,21 @@ impl Relay { }.boxed()); }, res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, + _ = cloned_signal_rx.recv() => { + log::info!("received shutdown signal, shutting down. Active tasks: {}", tasks.len()); + // set a timeout for waiting for tasks to be empty + // FIXME(itzmanish): make this configurable and revisit + let timeout = tokio::time::timeout(tokio::time::Duration::from_secs(20), async move { + while !tasks.is_empty() { + // sleep 500ms before checking again + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + } + }); + if let Err(e) = timeout.await { + log::warn!("timed out waiting for tasks to be empty: {}", e); + } + break Ok(()); + } } } } diff --git a/moq-relay-ietf/src/remote.rs b/moq-relay-ietf/src/remote.rs index 1b412d0b..636115dc 100644 --- a/moq-relay-ietf/src/remote.rs +++ b/moq-relay-ietf/src/remote.rs @@ -12,7 +12,9 @@ use futures::StreamExt; use moq_native_ietf::quic; use moq_transport::coding::TrackNamespace; use moq_transport::serve::{Track, TrackReader, TrackWriter}; +use moq_transport::session::SessionMigration; use moq_transport::watch::State; +use tokio::sync::broadcast; use url::Url; use crate::Api; @@ -72,13 +74,18 @@ impl RemotesProducer { } /// Run the remotes producer to serve remote requests. - pub async fn run(mut self) -> anyhow::Result<()> { + pub async fn run( + mut self, + signal_rx: broadcast::Receiver, + ) -> anyhow::Result<()> { let mut tasks = FuturesUnordered::new(); loop { tokio::select! { Some(mut remote) = self.next() => { let url = remote.url.clone(); + // Each remote task needs its own receiver + let remote_signal_rx = signal_rx.resubscribe(); // Spawn a task to serve the remote tasks.push(async move { @@ -87,7 +94,7 @@ impl RemotesProducer { log::warn!("serving remote: {:?}", info); // Run the remote producer - if let Err(err) = remote.run().await { + if let Err(err) = remote.run(remote_signal_rx).await { log::warn!("failed serving remote: {:?}, error: {}", info, err); } @@ -225,13 +232,16 @@ impl RemoteProducer { Self { info, state } } - pub async fn run(&mut self) -> anyhow::Result<()> { + pub async fn run( + &mut self, + signal_rx: broadcast::Receiver, + ) -> anyhow::Result<()> { // TODO reuse QUIC and MoQ sessions let (session, _quic_client_initial_cid) = self.quic.connect(&self.url).await?; let (session, subscriber) = moq_transport::session::Subscriber::connect(session).await?; // Run the session - let mut session = session.run().boxed(); + let mut session = session.run(Some(signal_rx)).boxed(); let mut tasks = FuturesUnordered::new(); let mut done = None; diff --git a/moq-relay-ietf/src/session.rs b/moq-relay-ietf/src/session.rs index b55748b8..3efd1d4a 100644 --- a/moq-relay-ietf/src/session.rs +++ b/moq-relay-ietf/src/session.rs @@ -1,7 +1,7 @@ -use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; -use moq_transport::session::SessionError; - use crate::{Consumer, Producer}; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; +use moq_transport::session::{SessionError, SessionMigration}; +use tokio::sync::broadcast; pub struct Session { pub session: moq_transport::session::Session, @@ -11,9 +11,12 @@ pub struct Session { impl Session { /// Run the session, producer, and consumer as necessary. - pub async fn run(self) -> Result<(), SessionError> { + pub async fn run( + self, + signal_rx: broadcast::Receiver, + ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); - tasks.push(self.session.run().boxed()); + tasks.push(self.session.run(Some(signal_rx)).boxed()); if let Some(producer) = self.producer { tasks.push(producer.run().boxed()); diff --git a/moq-sub/src/main.rs b/moq-sub/src/main.rs index e38154f2..5c163948 100644 --- a/moq-sub/src/main.rs +++ b/moq-sub/src/main.rs @@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> { let mut media = Media::new(subscriber, tracks, out, config.catalog).await?; tokio::select! { - res = session.run() => res.context("session error")?, + res = session.run(None) => res.context("session error")?, res = media.run() => res.context("media error")?, } diff --git a/moq-transport/src/message/mod.rs b/moq-transport/src/message/mod.rs index 267368c7..6d99dd19 100644 --- a/moq-transport/src/message/mod.rs +++ b/moq-transport/src/message/mod.rs @@ -222,3 +222,9 @@ message_types! { PublishOk = 0x1e, PublishError = 0x1f, } + +pub enum MiscMessage { + GoAway, + MaxRequestId, + RequestsBlocked, +} diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index e5d9efb6..7d6cdbf3 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -12,10 +12,12 @@ mod writer; pub use announce::*; pub use announced::*; pub use error::*; +use log::info; pub use publisher::*; pub use subscribe::*; pub use subscribed::*; pub use subscriber::*; +use tokio::sync::broadcast; pub use track_status_requested::*; use reader::*; @@ -24,8 +26,8 @@ use writer::*; use futures::{stream::FuturesUnordered, StreamExt}; use std::sync::{atomic, Arc, Mutex}; -use crate::coding::KeyValuePairs; -use crate::message::Message; +use crate::coding::{KeyValuePairs, SessionUri}; +use crate::message::{GoAway, Message}; use crate::mlog; use crate::watch::Queue; use crate::{message, setup}; @@ -51,6 +53,11 @@ pub struct Session { mlog: Option>>, } +#[derive(Clone, Debug)] +pub struct SessionMigration { + pub uri: String, +} + impl Session { // Helper for determining the largest supported version fn largest_common(a: &[T], b: &[T]) -> Option { @@ -197,7 +204,33 @@ impl Session { /// Run Tasks for the session, including sending of control messages, receiving and processing /// inbound control messages, receiving and processing new inbound uni-directional QUIC streams, /// and receiving and processing QUIC datagrams received - pub async fn run(self) -> Result<(), SessionError> { + pub async fn run( + self, + signal_rx: Option>, + ) -> Result<(), SessionError> { + let mut cloned_outgoing = self.outgoing.clone(); + + // Spawn a task that waits for shutdown signal and pushes GOAWAY + // This runs independently and doesn't affect the main session tasks + if let Some(mut signal_rx) = signal_rx { + tokio::spawn(async move { + if let Ok(info) = signal_rx.recv().await { + log::info!( + "received terminate/interrupt signal, sending GOAWAY: {:#?}", + info + ); + let msg = GoAway { + uri: SessionUri(info.uri), + }; + if let Err(e) = cloned_outgoing.push(Message::GoAway(msg)) { + log::error!("failed to push GOAWAY: {:#?}", e); + } else { + log::info!("GOAWAY message queued successfully"); + } + } + }); + } + tokio::select! { res = Self::run_recv(self.recver, self.publisher, self.subscriber.clone(), self.mlog.clone()) => res, res = Self::run_send(self.sender, self.outgoing, self.mlog.clone()) => res, @@ -340,6 +373,18 @@ impl Session { Err(msg) => msg, }; + let msg = match msg { + Message::GoAway(goaway) => { + info!("Received GOAWAY: {:?}", goaway); + subscriber + .as_mut() + .ok_or(SessionError::RoleViolation)? + .handle_go_away(goaway)?; + continue; + } + _ => msg, + }; + // TODO GOAWAY, MAX_REQUEST_ID, REQUESTS_BLOCKED log::warn!("Unimplemented message type received: {:?}", msg); return Err(SessionError::unimplemented(&format!( diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index a9fa2e16..8f0d2677 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -4,6 +4,8 @@ use std::{ sync::{atomic, Arc, Mutex}, }; +use log::info; + use crate::{ coding::{Decode, TrackNamespace}, data, @@ -159,6 +161,26 @@ impl Subscriber { res } + pub(super) fn handle_go_away(&mut self, goaway: message::GoAway) -> Result<(), SessionError> { + info!( + "Received GOAWAY: {:?}, sending unsubscribe for all active subscriptions", + goaway + ); + + // Collect all subscription IDs first to avoid holding the lock while calling remove_subscribe + let ids: Vec = { + let subscribes = self.subscribes.lock().unwrap(); + subscribes.keys().copied().collect() + }; + + // Remove each subscription (this will acquire the lock internally) + for id in ids { + self.remove_subscribe(id); + } + + Ok(()) + } + /// Handle the reception of a PublishNamespace message from the publisher. fn recv_publish_namespace( &mut self, From cc77231b83758f9500db199ffdb535c8144bbcd1 Mon Sep 17 00:00:00 2001 From: Manish Date: Mon, 24 Nov 2025 18:04:40 +0530 Subject: [PATCH 2/2] refactor: remove signal_rx parameter from session.run() and handle migration signals internally --- moq-clock-ietf/src/main.rs | 6 +- moq-pub/src/main.rs | 2 +- moq-relay-ietf/src/main.rs | 7 +++ moq-relay-ietf/src/relay.rs | 76 ++++++++++++++----------- moq-relay-ietf/src/remote.rs | 16 ++---- moq-relay-ietf/src/session.rs | 7 +-- moq-sub/src/main.rs | 2 +- moq-transport/src/session/mod.rs | 29 +++++++--- moq-transport/src/session/publisher.rs | 4 +- moq-transport/src/session/subscribed.rs | 9 +++ moq-transport/src/session/subscriber.rs | 11 ++-- 11 files changed, 97 insertions(+), 72 deletions(-) diff --git a/moq-clock-ietf/src/main.rs b/moq-clock-ietf/src/main.rs index afb525bc..199116cb 100644 --- a/moq-clock-ietf/src/main.rs +++ b/moq-clock-ietf/src/main.rs @@ -64,7 +64,7 @@ async fn main() -> anyhow::Result<()> { let clock_publisher = clock::Publisher::new_datagram(track_writer.datagrams()?); tokio::select! { - res = session.run(None) => res.context("session error")?, + res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, } @@ -80,7 +80,7 @@ async fn main() -> anyhow::Result<()> { let clock_publisher = clock::Publisher::new(track_writer.subgroups()?); tokio::select! { - res = session.run(None) => res.context("session error")?, + res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, } @@ -104,7 +104,7 @@ async fn main() -> anyhow::Result<()> { let clock_subscriber = clock::Subscriber::new(track_reader); tokio::select! { - res = session.run(None) => res.context("session error")?, + res = session.run() => res.context("session error")?, res = clock_subscriber.run() => res.context("clock error")?, res = subscriber.subscribe(track_writer) => res.context("failed to subscribe to track")?, } diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index 74eb66c6..7e71764c 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -76,7 +76,7 @@ async fn main() -> anyhow::Result<()> { .context("failed to create MoQ Transport publisher")?; tokio::select! { - res = session.run(None) => res.context("session error")?, + res = session.run() => res.context("session error")?, res = run_media(media) => { res.context("media error")? }, diff --git a/moq-relay-ietf/src/main.rs b/moq-relay-ietf/src/main.rs index 995b28ca..256bac0b 100644 --- a/moq-relay-ietf/src/main.rs +++ b/moq-relay-ietf/src/main.rs @@ -68,6 +68,12 @@ pub struct Cli { /// Requires --dev to enable the web server. Only serves files by exact CID - no index. #[arg(long)] pub mlog_serve: bool, + + /// The public URL we advertise to other origins. + /// The provided certificate must be valid for this address. + #[arg(long)] + #[arg(default_value = "https://localhost:4443")] + pub public_url: Option, } #[tokio::main] @@ -105,6 +111,7 @@ async fn main() -> anyhow::Result<()> { // Create a QUIC server for media. let relay = Relay::new(RelayConfig { + public_url: cli.public_url, tls: tls.clone(), bind: cli.bind, qlog_dir: qlog_dir_for_relay, diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 0f2c1181..eaa1c235 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -33,10 +33,14 @@ pub struct RelayConfig { /// Our hostname which we advertise to other origins. /// We use QUIC, so the certificate must be valid for this address. pub node: Option, + + /// The public URL we advertise to other origins. + pub public_url: Option, } /// MoQ Relay server. pub struct Relay { + public_url: Option, quic: quic::Endpoint, announce_url: Option, mlog_dir: Option, @@ -85,6 +89,7 @@ impl Relay { }); Ok(Self { + public_url: config.public_url, quic, announce_url: config.announce, mlog_dir: config.mlog_dir, @@ -105,7 +110,7 @@ impl Relay { let mut signal_int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?; - let (signal_tx, signal_rx) = broadcast::channel::(16); + let (signal_tx, mut signal_rx) = broadcast::channel::(16); // Get server address early for the shutdown signal let server_addr = self @@ -114,7 +119,12 @@ impl Relay { .as_ref() .context("missing TLS certificate")? .local_addr()?; - let shutdown_uri = format!("https://{}", server_addr); + // FIXME(itzmanish): this gives [::]:4433, which is not a valid URL + let shutdown_uri = if let Some(public_url) = &self.public_url { + public_url.clone().into() + } else { + format!("https://{}", server_addr) + }; // Spawn task to listen for SIGTERM and broadcast shutdown let signal_tx_clone = signal_tx.clone(); @@ -148,8 +158,7 @@ impl Relay { // Start the remotes producer task, if any let remotes = self.remotes.map(|(producer, consumer)| { - let signal_rx = signal_rx.resubscribe(); - tasks.push(producer.run(signal_rx).boxed()); + tasks.push(producer.run().boxed()); consumer }); @@ -166,10 +175,13 @@ impl Relay { .context("failed to establish forward connection")?; // Create the MoQ session over the connection - let (session, publisher, subscriber) = - moq_transport::session::Session::connect(session, None) - .await - .context("failed to establish forward session")?; + let (session, publisher, subscriber) = moq_transport::session::Session::connect( + session, + None, + Some(signal_tx_clone.subscribe()), + ) + .await + .context("failed to establish forward session")?; // Create a normal looking session, except we never forward or register announces. let session = Session { @@ -184,15 +196,7 @@ impl Relay { let forward_producer = session.producer.clone(); - tasks.push( - async move { - session - .run(signal_tx_clone.subscribe()) - .await - .context("forwarding failed") - } - .boxed(), - ); + tasks.push(async move { session.run().await.context("forwarding failed") }.boxed()); forward_producer } else { @@ -202,7 +206,6 @@ impl Relay { // Start the QUIC server loop let mut server = self.quic.server.context("missing TLS certificate")?; log::info!("listening on {}", server.local_addr()?); - let mut cloned_signal_rx = signal_rx.resubscribe(); loop { tokio::select! { @@ -218,12 +221,12 @@ impl Relay { let remotes = remotes.clone(); let forward = forward_producer.clone(); let api = self.api.clone(); - let session_signal_rx = signal_rx.resubscribe(); + let session_signal_rx = signal_tx_clone.subscribe(); // Spawn a new task to handle the connection tasks.push(async move { // Create the MoQ session over the connection (setup handshake etc) - let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path).await { + let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path, Some(session_signal_rx)).await { Ok(session) => session, Err(err) => { log::warn!("failed to accept MoQ session: {}", err); @@ -238,7 +241,7 @@ impl Relay { consumer: subscriber.map(|subscriber| Consumer::new(subscriber, locals, api, forward)), }; - if let Err(err) = session.run(session_signal_rx).await { + if let Err(err) = session.run().await { log::warn!("failed to run MoQ session: {}", err); } @@ -246,18 +249,27 @@ impl Relay { }.boxed()); }, res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, - _ = cloned_signal_rx.recv() => { - log::info!("received shutdown signal, shutting down. Active tasks: {}", tasks.len()); - // set a timeout for waiting for tasks to be empty - // FIXME(itzmanish): make this configurable and revisit - let timeout = tokio::time::timeout(tokio::time::Duration::from_secs(20), async move { - while !tasks.is_empty() { - // sleep 500ms before checking again - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + _ = signal_rx.recv() => { + log::info!("received shutdown signal, waiting for {} active tasks to complete", tasks.len()); + + // Give sessions a moment to send GOAWAY messages + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + // Stop accepting new connections and wait for existing tasks to complete + log::info!("draining {} remaining tasks...", tasks.len()); + let shutdown_timeout = tokio::time::Duration::from_secs(20); + let result = tokio::time::timeout(shutdown_timeout, async { + // Actually poll tasks to completion + while let Some(res) = tasks.next().await { + if let Err(e) = res { + log::warn!("task failed during shutdown: {:?}", e); + } } - }); - if let Err(e) = timeout.await { - log::warn!("timed out waiting for tasks to be empty: {}", e); + }).await; + + match result { + Ok(_) => log::info!("all tasks completed successfully"), + Err(_) => log::warn!("timed out waiting for tasks after {}s", shutdown_timeout.as_secs()), } break Ok(()); } diff --git a/moq-relay-ietf/src/remote.rs b/moq-relay-ietf/src/remote.rs index 636115dc..fcf477db 100644 --- a/moq-relay-ietf/src/remote.rs +++ b/moq-relay-ietf/src/remote.rs @@ -74,18 +74,13 @@ impl RemotesProducer { } /// Run the remotes producer to serve remote requests. - pub async fn run( - mut self, - signal_rx: broadcast::Receiver, - ) -> anyhow::Result<()> { + pub async fn run(mut self) -> anyhow::Result<()> { let mut tasks = FuturesUnordered::new(); loop { tokio::select! { Some(mut remote) = self.next() => { let url = remote.url.clone(); - // Each remote task needs its own receiver - let remote_signal_rx = signal_rx.resubscribe(); // Spawn a task to serve the remote tasks.push(async move { @@ -94,7 +89,7 @@ impl RemotesProducer { log::warn!("serving remote: {:?}", info); // Run the remote producer - if let Err(err) = remote.run(remote_signal_rx).await { + if let Err(err) = remote.run().await { log::warn!("failed serving remote: {:?}, error: {}", info, err); } @@ -232,16 +227,13 @@ impl RemoteProducer { Self { info, state } } - pub async fn run( - &mut self, - signal_rx: broadcast::Receiver, - ) -> anyhow::Result<()> { + pub async fn run(&mut self) -> anyhow::Result<()> { // TODO reuse QUIC and MoQ sessions let (session, _quic_client_initial_cid) = self.quic.connect(&self.url).await?; let (session, subscriber) = moq_transport::session::Subscriber::connect(session).await?; // Run the session - let mut session = session.run(Some(signal_rx)).boxed(); + let mut session = session.run().boxed(); let mut tasks = FuturesUnordered::new(); let mut done = None; diff --git a/moq-relay-ietf/src/session.rs b/moq-relay-ietf/src/session.rs index 3efd1d4a..e8299e45 100644 --- a/moq-relay-ietf/src/session.rs +++ b/moq-relay-ietf/src/session.rs @@ -11,12 +11,9 @@ pub struct Session { impl Session { /// Run the session, producer, and consumer as necessary. - pub async fn run( - self, - signal_rx: broadcast::Receiver, - ) -> Result<(), SessionError> { + pub async fn run(self) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); - tasks.push(self.session.run(Some(signal_rx)).boxed()); + tasks.push(self.session.run().boxed()); if let Some(producer) = self.producer { tasks.push(producer.run().boxed()); diff --git a/moq-sub/src/main.rs b/moq-sub/src/main.rs index 5c163948..e38154f2 100644 --- a/moq-sub/src/main.rs +++ b/moq-sub/src/main.rs @@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> { let mut media = Media::new(subscriber, tracks, out, config.catalog).await?; tokio::select! { - res = session.run(None) => res.context("session error")?, + res = session.run() => res.context("session error")?, res = media.run() => res.context("media error")?, } diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index 7d6cdbf3..6afab414 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -51,6 +51,9 @@ pub struct Session { /// Optional mlog writer for MoQ Transport events /// Wrapped in Arc> to share across send/recv tasks when enabled mlog: Option>>, + + /// Optional signal receiver for migration events + signal_rx: Option>, } #[derive(Clone, Debug)] @@ -73,6 +76,7 @@ impl Session { recver: Reader, first_requestid: u64, mlog: Option, + signal_rx: Option>, ) -> (Self, Option, Option) { let next_requestid = Arc::new(atomic::AtomicU64::new(first_requestid)); let outgoing = Queue::default().split(); @@ -100,6 +104,7 @@ impl Session { subscriber: subscriber.clone(), outgoing: outgoing.1, mlog: mlog_shared, + signal_rx, }; (session, publisher, subscriber) @@ -110,6 +115,7 @@ impl Session { pub async fn connect( mut session: web_transport::Session, mlog_path: Option, + signal_rx: Option>, ) -> Result<(Session, Publisher, Subscriber), SessionError> { let mlog = mlog_path.and_then(|path| { mlog::MlogWriter::new(path) @@ -142,7 +148,7 @@ impl Session { // TODO: emit server_setup_parsed event // We are the client, so the first request id is 0 - let session = Session::new(session, sender, recver, 0, mlog); + let session = Session::new(session, sender, recver, 0, mlog, signal_rx); Ok((session.0, session.1.unwrap(), session.2.unwrap())) } @@ -151,6 +157,7 @@ impl Session { pub async fn accept( mut session: web_transport::Session, mlog_path: Option, + signal_rx: Option>, ) -> Result<(Session, Option, Option), SessionError> { let mut mlog = mlog_path.and_then(|path| { mlog::MlogWriter::new(path) @@ -195,7 +202,7 @@ impl Session { sender.encode(&server).await?; // We are the server, so the first request id is 1 - Ok(Session::new(session, sender, recver, 1, mlog)) + Ok(Session::new(session, sender, recver, 1, mlog, signal_rx)) } else { Err(SessionError::Version(client.versions, server_versions)) } @@ -204,15 +211,12 @@ impl Session { /// Run Tasks for the session, including sending of control messages, receiving and processing /// inbound control messages, receiving and processing new inbound uni-directional QUIC streams, /// and receiving and processing QUIC datagrams received - pub async fn run( - self, - signal_rx: Option>, - ) -> Result<(), SessionError> { + pub async fn run(self) -> Result<(), SessionError> { let mut cloned_outgoing = self.outgoing.clone(); // Spawn a task that waits for shutdown signal and pushes GOAWAY // This runs independently and doesn't affect the main session tasks - if let Some(mut signal_rx) = signal_rx { + if let Some(mut signal_rx) = self.signal_rx { tokio::spawn(async move { if let Ok(info) = signal_rx.recv().await { log::info!( @@ -233,7 +237,7 @@ impl Session { tokio::select! { res = Self::run_recv(self.recver, self.publisher, self.subscriber.clone(), self.mlog.clone()) => res, - res = Self::run_send(self.sender, self.outgoing, self.mlog.clone()) => res, + res = Self::run_send(self.sender, self.outgoing, self.subscriber.clone(), self.mlog.clone()) => res, res = Self::run_streams(self.webtransport.clone(), self.subscriber.clone()) => res, res = Self::run_datagrams(self.webtransport, self.subscriber) => res, } @@ -243,6 +247,7 @@ impl Session { async fn run_send( mut sender: Writer, mut outgoing: Queue, + mut subscriber: Option, mlog: Option>>, ) -> Result<(), SessionError> { while let Some(msg) = outgoing.pop().await { @@ -289,6 +294,12 @@ impl Session { } } + if let Message::GoAway(_m) = &msg { + subscriber + .iter_mut() + .for_each(|s| s.handle_go_away().unwrap_or(())); + } + sender.encode(&msg).await?; } @@ -379,7 +390,7 @@ impl Session { subscriber .as_mut() .ok_or(SessionError::RoleViolation)? - .handle_go_away(goaway)?; + .handle_go_away()?; continue; } _ => msg, diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 1d0c45b1..c861578a 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -75,14 +75,14 @@ impl Publisher { pub async fn accept( session: web_transport::Session, ) -> Result<(Session, Publisher), SessionError> { - let (session, publisher, _) = Session::accept(session, None).await?; + let (session, publisher, _) = Session::accept(session, None, None).await?; Ok((session, publisher.unwrap())) } pub async fn connect( session: web_transport::Session, ) -> Result<(Session, Publisher), SessionError> { - let (session, publisher, _) = Session::connect(session, None).await?; + let (session, publisher, _) = Session::connect(session, None, None).await?; Ok((session, publisher)) } diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index e87961fc..650513ed 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -273,6 +273,9 @@ impl Subscribed { let mut object_count = 0; while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { + // Check if subscription was cancelled before processing next object + state.lock().closed.clone()?; + let subgroup_object = data::SubgroupObjectExt { object_id_delta: 0, // before delta logic, used to be subgroup_object_reader.object_id, extension_headers: subgroup_object_reader.extension_headers.clone(), // Pass through extension headers @@ -325,6 +328,9 @@ impl Subscribed { let mut chunks_sent = 0; let mut bytes_sent = 0; while let Some(chunk) = subgroup_object_reader.read().await? { + // Check if subscription was cancelled before writing each chunk + state.lock().closed.clone()?; + log::trace!( "[PUBLISHER] serve_subgroup: sending payload chunk #{} for object #{} ({} bytes)", chunks_sent + 1, @@ -363,6 +369,9 @@ impl Subscribed { let mut datagram_count = 0; while let Some(datagram) = datagrams.read().await? { + // Check if subscription was cancelled before sending next datagram + self.state.lock().closed.clone()?; + // Determine datagram type based on extension headers presence let has_extension_headers = !datagram.extension_headers.is_empty(); let datagram_type = if has_extension_headers { diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 8f0d2677..7bcac626 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -67,13 +67,13 @@ impl Subscriber { /// Create an inbound/server QUIC connection, by accepting a bi-directional QUIC stream for control messages. pub async fn accept(session: web_transport::Session) -> Result<(Session, Self), SessionError> { - let (session, _, subscriber) = Session::accept(session, None).await?; + let (session, _, subscriber) = Session::accept(session, None, None).await?; Ok((session, subscriber.unwrap())) } /// Create an outbound/client QUIC connection, by opening a bi-directional QUIC stream for control messages. pub async fn connect(session: web_transport::Session) -> Result<(Session, Self), SessionError> { - let (session, _, subscriber) = Session::connect(session, None).await?; + let (session, _, subscriber) = Session::connect(session, None, None).await?; Ok((session, subscriber)) } @@ -161,11 +161,8 @@ impl Subscriber { res } - pub(super) fn handle_go_away(&mut self, goaway: message::GoAway) -> Result<(), SessionError> { - info!( - "Received GOAWAY: {:?}, sending unsubscribe for all active subscriptions", - goaway - ); + pub(super) fn handle_go_away(&mut self) -> Result<(), SessionError> { + info!("sending unsubscribe for all active subscriptions"); // Collect all subscription IDs first to avoid holding the lock while calling remove_subscribe let ids: Vec = {