diff --git a/Cargo.lock b/Cargo.lock index 4949f36..2635604 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -620,6 +620,16 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "795cbfc56d419a7ce47ccbb7504dd9a5b7c484c083c356e797de08bd988d9629" +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "fs_extra" version = "1.3.0" @@ -1282,9 +1292,11 @@ name = "moq-relay-ietf" version = "0.7.5" dependencies = [ "anyhow", + "async-trait", "axum", "clap", "env_logger", + "fs2", "futures", "hex", "hyper-serve", @@ -1292,6 +1304,8 @@ dependencies = [ "moq-api", "moq-native-ietf", "moq-transport", + "serde", + "serde_json", "tokio", "tower-http", "tracing", diff --git a/moq-clock-ietf/src/main.rs b/moq-clock-ietf/src/main.rs index 199116c..7d0f695 100644 --- a/moq-clock-ietf/src/main.rs +++ b/moq-clock-ietf/src/main.rs @@ -29,16 +29,12 @@ async fn main() -> anyhow::Result<()> { let tls = config.tls.load()?; // Create the QUIC endpoint - let quic = quic::Endpoint::new(quic::Config { - bind: config.bind, - qlog_dir: None, - tls, - })?; + let quic = quic::Endpoint::new(quic::Config::new(config.bind, None, tls))?; log::info!("connecting to server: url={}", config.url); // Connect to the server - let (session, connection_id) = quic.client.connect(&config.url).await?; + let (session, connection_id) = quic.client.connect(&config.url, None).await?; log::info!( "connected with CID: {} (use this to look up qlog/mlog on server)", diff --git a/moq-native-ietf/src/quic.rs b/moq-native-ietf/src/quic.rs index 0919fc2..8b6cdad 100644 --- a/moq-native-ietf/src/quic.rs +++ b/moq-native-ietf/src/quic.rs @@ -1,4 +1,6 @@ use std::{ + collections::HashSet, + fmt, fs::File, io::BufWriter, net, @@ -17,6 +19,25 @@ use futures::future::BoxFuture; use 
futures::stream::{FuturesUnordered, StreamExt}; use futures::FutureExt; +/// Represents the address family of the local QUIC socket. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AddressFamily { + Ipv4, + Ipv6, + /// IPv6 with dual-stack support (Linux) + Ipv6DualStack, +} + +impl fmt::Display for AddressFamily { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AddressFamily::Ipv4 => write!(f, "IPv4"), + AddressFamily::Ipv6 => write!(f, "IPv6"), + AddressFamily::Ipv6DualStack => write!(f, "IPv6 (dual stack)"), + } + } +} + /// Build a TransportConfig with our standard settings /// /// This is used both for the base endpoint config and when creating @@ -57,23 +78,60 @@ impl Default for Args { impl Args { pub fn load(&self) -> anyhow::Result { let tls = self.tls.load()?; - Ok(Config { - bind: self.bind, - qlog_dir: self.qlog_dir.clone(), - tls, - }) + Ok(Config::new(self.bind, self.qlog_dir.clone(), tls)) } } pub struct Config { - pub bind: net::SocketAddr, + pub bind: Option, + pub socket: net::UdpSocket, pub qlog_dir: Option, pub tls: tls::Config, + pub tags: HashSet, +} + +impl Config { + pub fn new(bind: net::SocketAddr, qlog_dir: Option, tls: tls::Config) -> Self { + Self { + bind: Some(bind), + socket: net::UdpSocket::bind(bind) + .context("failed to bind socket") + .unwrap(), + qlog_dir, + tls, + tags: HashSet::new(), + } + } + + pub fn with_socket( + socket: net::UdpSocket, + qlog_dir: Option, + tls: tls::Config, + ) -> Self { + Self { + bind: None, + socket, + qlog_dir, + tls, + tags: HashSet::new(), + } + } + + pub fn with_tag(mut self, tag: String) -> Self { + self.tags.insert(tag); + self + } } pub struct Endpoint { pub client: Client, pub server: Option, + /// Tags associated with this endpoint + /// These are used to filter endpoints for different purposes, for eg- + /// "server" tag is used to filter endpoints for relay server + /// "forward" tag is used to filter endpoints for forwarder + /// This is upto the 
user to define and use + pub tags: HashSet, } impl Endpoint { @@ -111,13 +169,13 @@ impl Endpoint { // There's a bit more boilerplate to make a generic endpoint. let runtime = quinn::default_runtime().context("no async runtime")?; let endpoint_config = quinn::EndpointConfig::default(); - let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?; + let socket = config.socket; // Create the generic QUIC endpoint. let quic = quinn::Endpoint::new(endpoint_config, server_config.clone(), socket, runtime) .context("failed to create QUIC endpoint")?; - let server = server_config.clone().map(|base_server_config| Server { + let server = server_config.map(|base_server_config| Server { quic: quic.clone(), accept: Default::default(), qlog_dir: config.qlog_dir.map(Arc::new), @@ -130,7 +188,11 @@ impl Endpoint { transport, }; - Ok(Self { client, server }) + Ok(Self { + client, + server, + tags: config.tags, + }) } } @@ -270,7 +332,34 @@ pub struct Client { } impl Client { - pub async fn connect(&self, url: &Url) -> anyhow::Result<(web_transport::Session, String)> { + /// Returns the local address of the QUIC socket. + pub fn local_addr(&self) -> anyhow::Result { + self.quic + .local_addr() + .context("failed to get local address") + } + + /// Returns the address family of the local QUIC socket. 
+ pub fn address_family(&self) -> anyhow::Result { + let local_addr = self + .quic + .local_addr() + .context("failed to get local socket address")?; + + if local_addr.is_ipv4() { + Ok(AddressFamily::Ipv4) + } else if cfg!(target_os = "linux") { + Ok(AddressFamily::Ipv6DualStack) + } else { + Ok(AddressFamily::Ipv6) + } + } + + pub async fn connect( + &self, + url: &Url, + socket_addr: Option, + ) -> anyhow::Result<(web_transport::Session, String)> { let mut config = self.config.clone(); // TODO support connecting to both ALPNs at the same time @@ -303,12 +392,15 @@ impl Client { let host = url.host().context("invalid DNS name")?.to_string(); let port = url.port().unwrap_or(443); - // Look up the DNS entry. - let addr = tokio::net::lookup_host((host.clone(), port)) - .await - .context("failed DNS lookup")? - .next() - .context("no DNS entries")?; + // Look up the DNS entry and filter by socket address family. + let addr = match socket_addr { + Some(addr) => addr, + None => { + // Default DNS resolution logic + self.resolve_dns(&host, port, self.address_family()?) + .await? + } + }; let connection = self.quic.connect_with(config, addr, &host)?.await?; @@ -328,4 +420,83 @@ impl Client { Ok((session.into(), connection_id_hex)) } + + /// Default DNS resolution logic that filters results by address family. + async fn resolve_dns( + &self, + host: &str, + port: u16, + address_family: AddressFamily, + ) -> anyhow::Result { + let local_addr = self.local_addr()?; + + // Collect all DNS results + let addrs: Vec = tokio::net::lookup_host((host, port)) + .await + .context("failed DNS lookup")? 
+ .collect(); + + if addrs.is_empty() { + anyhow::bail!("DNS lookup for host '{}' returned no addresses", host); + } + + // Log all DNS results for debugging + log::debug!( + "DNS lookup for {}, family {:?}: found {} results", + host, + address_family, + addrs.len() + ); + for (i, addr) in addrs.iter().enumerate() { + log::debug!( + " DNS[{}]: {} ({})", + i, + addr, + if addr.is_ipv4() { "IPv4" } else { "IPv6" } + ); + } + + // Filter DNS results to match our local socket's address family + let compatible_addr = match address_family { + AddressFamily::Ipv4 => { + // IPv4 socket: filter to IPv4 addresses + addrs + .iter() + .find(|a| a.is_ipv4()) + .cloned() + .context(format!( + "No IPv4 address found for host '{}' (local socket is IPv4: {})", + host, local_addr + ))? + } + AddressFamily::Ipv6DualStack => { + // IPv6 socket on Linux: dual-stack, use first result + log::debug!( + "Using first DNS result (Linux IPv6 dual-stack): {}", + addrs[0] + ); + addrs[0] + } + AddressFamily::Ipv6 => { + // IPv6 socket non-Linux: filter to IPv6 addresses + addrs + .iter() + .find(|a| a.is_ipv6()) + .cloned() + .context(format!( + "No IPv6 address found for host '{}' (local socket is IPv6: {})", + host, local_addr + ))? 
+ } + }; + + log::debug!( + "Connecting from {} to {} (selected from {} DNS results)", + local_addr, + compatible_addr, + addrs.len() + ); + + Ok(compatible_addr) + } } diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index 7e71764..cd9350f 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -57,14 +57,14 @@ async fn main() -> anyhow::Result<()> { let tls = cli.tls.load()?; - let quic = quic::Endpoint::new(moq_native_ietf::quic::Config { - bind: cli.bind, - qlog_dir: None, - tls: tls.clone(), - })?; + let quic = quic::Endpoint::new(moq_native_ietf::quic::Config::new( + cli.bind, + None, + tls.clone(), + ))?; log::info!("connecting to relay: url={}", cli.url); - let (session, connection_id) = quic.client.connect(&cli.url).await?; + let (session, connection_id) = quic.client.connect(&cli.url, None).await?; log::info!( "connected with CID: {} (use this to look up qlog/mlog on server)", diff --git a/moq-relay-ietf/Cargo.toml b/moq-relay-ietf/Cargo.toml index cd3e4d8..13ed459 100644 --- a/moq-relay-ietf/Cargo.toml +++ b/moq-relay-ietf/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "moq-relay-ietf" description = "Media over QUIC" -authors = ["Luke Curley"] -repository = "https://github.com/englishm/moq-rs" +authors = ["Luke Curley", "Manish Kumar Pandit"] +repository = "https://github.com/cloudflare/moq-rs" license = "MIT OR Apache-2.0" version = "0.7.5" @@ -11,10 +11,19 @@ edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] categories = ["multimedia", "network-programming", "web-programming"] +[lib] +name = "moq_relay_ietf" +path = "src/lib.rs" + +[[bin]] +name = "moq-relay-ietf" +path = "src/bin/moq-relay-ietf/main.rs" + [dependencies] moq-transport = { path = "../moq-transport", version = "0.11" } moq-native-ietf = { path = "../moq-native-ietf", version = "0.5" } moq-api = { path = "../moq-api", version = "0.2" } +web-transport = { workspace = true } # QUIC url = "2" @@ -22,6 +31,7 @@ url = "2" # Async stuff tokio = { version = 
"1", features = ["full"] } futures = "0.3" +async-trait = "0.1" # Web server to serve the fingerprint axum = { version = "0.7", features = ["tokio"] } @@ -31,6 +41,13 @@ hyper-serve = { version = "0.6", features = [ tower-http = { version = "0.5", features = ["cors"] } hex = "0.4" +# Serialization +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# File locking +fs2 = "0.4" + # Error handling anyhow = { version = "1", features = ["backtrace"] } @@ -42,3 +59,4 @@ log = { workspace = true } env_logger = { workspace = true } tracing = "0.1" tracing-subscriber = "0.3" +thiserror = "2.0.17" diff --git a/moq-relay-ietf/src/bin/moq-relay-ietf/file_coordinator.rs b/moq-relay-ietf/src/bin/moq-relay-ietf/file_coordinator.rs new file mode 100644 index 0000000..8e40369 --- /dev/null +++ b/moq-relay-ietf/src/bin/moq-relay-ietf/file_coordinator.rs @@ -0,0 +1,252 @@ +//! File-based coordinator for multi-relay deployments. +//! +//! This coordinator uses a shared JSON file with file locking to coordinate +//! namespace registration across multiple relay instances. No separate +//! server process is required. 
+ +use std::collections::HashMap; +use std::fs::{File, OpenOptions}; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use fs2::FileExt; +use moq_native_ietf::quic::Client; +use moq_transport::coding::TrackNamespace; +use serde::{Deserialize, Serialize}; +use url::Url; + +use moq_relay_ietf::{ + Coordinator, CoordinatorError, CoordinatorResult, NamespaceOrigin, NamespaceRegistration, +}; + +/// Data stored in the shared file +#[derive(Debug, Default, Serialize, Deserialize)] +struct CoordinatorData { + /// Maps namespace path (e.g., "/foo/bar") to relay URL + namespaces: HashMap, +} + +impl CoordinatorData { + fn namespace_key(namespace: &TrackNamespace) -> String { + namespace.to_utf8_path() + } +} + +/// Handle that unregisters a namespace when dropped +struct NamespaceUnregisterHandle { + namespace: TrackNamespace, + file_path: PathBuf, +} + +impl Drop for NamespaceUnregisterHandle { + fn drop(&mut self) { + if let Err(err) = unregister_namespace_sync(&self.file_path, &self.namespace) { + log::warn!("failed to unregister namespace on drop: {}", err); + } + } +} + +/// Synchronous helper for unregistering namespace (used in Drop) +fn unregister_namespace_sync(file_path: &Path, namespace: &TrackNamespace) -> Result<()> { + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(file_path)?; + + file.lock_exclusive()?; + + let mut data = read_data(&file)?; + let key = CoordinatorData::namespace_key(namespace); + + log::debug!("unregistering namespace: {}", key); + data.namespaces.remove(&key); + + write_data(&file, &data)?; + file.unlock()?; + + Ok(()) +} + +/// Read coordinator data from file +fn read_data(file: &File) -> Result { + let mut file = file; + file.seek(SeekFrom::Start(0))?; + + let mut contents = String::new(); + file.read_to_string(&mut contents)?; + + if contents.is_empty() { + return 
Ok(CoordinatorData::default()); + } + + serde_json::from_str(&contents).context("failed to parse coordinator data") +} + +/// Write coordinator data to file +fn write_data(file: &File, data: &CoordinatorData) -> Result<()> { + let mut file = file; + file.seek(SeekFrom::Start(0))?; + file.set_len(0)?; + + let json = serde_json::to_string_pretty(data)?; + file.write_all(json.as_bytes())?; + file.flush()?; + + Ok(()) +} + +/// A coordinator that uses a shared file for state storage. +/// +/// Multiple relay instances can use the same file to share namespace/track +/// registration data. File locking ensures safe concurrent access. +pub struct FileCoordinator { + /// Path to the shared coordination file + file_path: PathBuf, + /// URL of this relay (used when registering namespaces) + relay_url: Url, +} + +impl FileCoordinator { + /// Create a new file-based coordinator. + /// + /// # Arguments + /// * `file_path` - Path to the shared coordination file + /// * `relay_url` - URL of this relay instance (advertised to other relays) + pub fn new(file_path: impl AsRef, relay_url: Url) -> Self { + Self { + file_path: file_path.as_ref().to_path_buf(), + relay_url, + } + } +} + +#[async_trait] +impl Coordinator for FileCoordinator { + async fn register_namespace( + &self, + namespace: &TrackNamespace, + ) -> CoordinatorResult { + let namespace = namespace.clone(); + let relay_url = self.relay_url.to_string(); + let file_path = self.file_path.clone(); + + // Run blocking file I/O in a separate thread + let ns_clone = namespace.clone(); + tokio::task::spawn_blocking(move || { + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(&file_path)?; + + file.lock_exclusive()?; + + let mut data = read_data(&file)?; + let key = CoordinatorData::namespace_key(&ns_clone); + + log::info!("registering namespace: {} -> {}", key, relay_url); + data.namespaces.insert(key, relay_url); + + write_data(&file, &data)?; + file.unlock()?; + + Ok::<_, 
anyhow::Error>(()) + }) + .await??; + + let handle = NamespaceUnregisterHandle { + namespace, + file_path: self.file_path.clone(), + }; + + Ok(NamespaceRegistration::new(handle)) + } + + // FIXME(itzmanish): Not being called currently but we need to call this on publish_namespace_done + // currently unregister happens on drop of namespace + async fn unregister_namespace(&self, namespace: &TrackNamespace) -> CoordinatorResult<()> { + let namespace = namespace.clone(); + let file_path = self.file_path.clone(); + + tokio::task::spawn_blocking(move || unregister_namespace_sync(&file_path, &namespace)) + .await??; + + Ok(()) + } + + async fn lookup( + &self, + namespace: &TrackNamespace, + ) -> CoordinatorResult<(NamespaceOrigin, Option<Client>)> { + let namespace = namespace.clone(); + let file_path = self.file_path.clone(); + + let result = tokio::task::spawn_blocking( + move || -> Result<Option<(NamespaceOrigin, Option<Client>)>> { + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(&file_path)?; + + file.lock_shared()?; + + let data = read_data(&file)?; + let key = CoordinatorData::namespace_key(&namespace); + + log::debug!("looking up namespace: {}", key); + + // Try exact match first + if let Some(relay_url) = data.namespaces.get(&key) { + file.unlock()?; + let url = Url::parse(relay_url)?; + return Ok(Some((NamespaceOrigin::new(namespace, url), None))); + } + + // Try prefix matching (find longest matching prefix) + let mut best_match: Option<(String, String)> = None; + for (registered_key, url) in &data.namespaces { + // A registered key is a prefix only if `key` ends there or continues at a '/' boundary. + // FIXME(itzmanish): compare on TupleField instead of working on strings + let is_prefix = key + .strip_prefix(registered_key.as_str()) + .map(|rest| rest.is_empty() || rest.starts_with('/')) + .unwrap_or(false); + match best_match { + Some((ns, _)) if is_prefix && ns.len() < registered_key.len() => { + best_match = Some((registered_key.clone(), url.clone())); + } + None if is_prefix => { + best_match = Some((registered_key.clone(), url.clone())); + } + _
=> {} + } + } + + file.unlock()?; + + if let Some((matched_key, relay_url)) = best_match { + let matched_ns = TrackNamespace::from_utf8_path(&matched_key); + let url = Url::parse(&relay_url)?; + return Ok(Some((NamespaceOrigin::new(matched_ns, url), None))); + } + + Ok(None) + }, + ) + .await??; + + result.ok_or(CoordinatorError::NamespaceNotFound) + } + + async fn shutdown(&self) -> CoordinatorResult<()> { + // Nothing to clean up - file will be unlocked automatically + Ok(()) + } +} diff --git a/moq-relay-ietf/src/main.rs b/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs similarity index 73% rename from moq-relay-ietf/src/main.rs rename to moq-relay-ietf/src/bin/moq-relay-ietf/main.rs index 995b28c..533d40e 100644 --- a/moq-relay-ietf/src/main.rs +++ b/moq-relay-ietf/src/bin/moq-relay-ietf/main.rs @@ -1,26 +1,14 @@ -use clap::Parser; - -mod api; -mod consumer; -mod local; -mod producer; -mod relay; -mod remote; -mod session; -mod web; - -pub use api::*; -pub use consumer::*; -pub use local::*; -pub use producer::*; -pub use relay::*; -pub use remote::*; -pub use session::*; -pub use web::*; +mod file_coordinator; +use std::sync::Arc; use std::{net, path::PathBuf}; + +use clap::Parser; use url::Url; +use file_coordinator::FileCoordinator; +use moq_relay_ietf::{Relay, RelayConfig, Web, WebConfig}; + #[derive(Parser, Clone)] pub struct Cli { /// Listen on this address @@ -68,6 +56,15 @@ pub struct Cli { /// Requires --dev to enable the web server. Only serves files by exact CID - no index. #[arg(long)] pub mlog_serve: bool, + + /// Path to the shared coordinator file for multi-relay coordination. + /// Multiple relay instances can share namespace/track registration via this file. + /// User doesn't have to explicitly create and populate anything. This path will be + /// used by file coordinator to store namespace/track registration information. + /// User need to make sure if multiple relay's are being used all of them have same path + /// to this file. 
+ #[arg(long, default_value = "/tmp/moq-coordinator.json")] + pub coordinator_file: PathBuf, } #[tokio::main] @@ -103,15 +100,27 @@ async fn main() -> anyhow::Result<()> { None }; + // Build the relay URL from the node or bind address + let relay_url = cli + .node + .clone() + .unwrap_or_else(|| Url::parse(&format!("https://{}", cli.bind)).unwrap()); + + // Create the file-based coordinator for multi-relay coordination + let coordinator = Arc::new(FileCoordinator::new(&cli.coordinator_file, relay_url)); + + log::info!("using file coordinator: {}", cli.coordinator_file.display()); + // Create a QUIC server for media. let relay = Relay::new(RelayConfig { tls: tls.clone(), - bind: cli.bind, + bind: Some(cli.bind), + endpoints: vec![], qlog_dir: qlog_dir_for_relay, mlog_dir: mlog_dir_for_relay, node: cli.node, - api: cli.api, announce: cli.announce, + coordinator, })?; if cli.dev { diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index 85242d0..a93d112 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ @@ -5,28 +7,28 @@ use moq_transport::{ session::{Announced, SessionError, Subscriber}, }; -use crate::{Api, Locals, Producer}; +use crate::{Coordinator, Locals, Producer}; /// Consumer of tracks from a remote Publisher #[derive(Clone)] pub struct Consumer { - remote: Subscriber, + subscriber: Subscriber, locals: Locals, - api: Option, + coordinator: Arc, forward: Option, // Forward all announcements to this subscriber } impl Consumer { pub fn new( - remote: Subscriber, + subscriber: Subscriber, locals: Locals, - api: Option, + coordinator: Arc, forward: Option, ) -> Self { Self { - remote, + subscriber, locals, - api, + coordinator, forward, } } @@ -38,7 +40,7 @@ impl Consumer { loop { tokio::select! 
{ // Handle a new announce request - Some(announce) = self.remote.announced() => { + Some(announce) = self.subscriber.announced() => { let this = self.clone(); tasks.push(async move { @@ -64,13 +66,11 @@ impl Consumer { // Produce the tracks for this announce and return the reader let (_, mut request, reader) = Tracks::new(announce.namespace.clone()).produce(); - // Start refreshing the API origin, if any - if let Some(api) = self.api.as_ref() { - let mut refresh = api.set_origin(reader.namespace.to_utf8_path()).await?; - tasks.push( - async move { refresh.run().await.context("failed refreshing origin") }.boxed(), - ); - } + // Register namespace with the coordinator + let _namespace_registration = self + .coordinator + .register_namespace(&reader.namespace) + .await?; // Register the local tracks, unregister on drop let _register = self.locals.register(reader.clone()).await?; @@ -100,7 +100,7 @@ impl Consumer { // Wait for the next subscriber and serve the track. Some(track) = request.next() => { - let mut remote = self.remote.clone(); + let mut subscriber = self.subscriber.clone(); // Spawn a new task to handle the subscribe tasks.push(async move { @@ -108,7 +108,7 @@ impl Consumer { log::info!("forwarding subscribe: {:?}", info); // Forward the subscribe request - if let Err(err) = remote.subscribe(track).await { + if let Err(err) = subscriber.subscribe(track).await { log::warn!("failed forwarding subscribe: {:?}, error: {}", info, err) } diff --git a/moq-relay-ietf/src/coordinator.rs b/moq-relay-ietf/src/coordinator.rs new file mode 100644 index 0000000..b84e5b6 --- /dev/null +++ b/moq-relay-ietf/src/coordinator.rs @@ -0,0 +1,188 @@ +use async_trait::async_trait; +use moq_native_ietf::quic; +use moq_transport::coding::TrackNamespace; +use url::Url; + +#[derive(Debug, thiserror::Error)] +pub enum CoordinatorError { + #[error("namespace not found")] + NamespaceNotFound, + + #[error("namespace already registered")] + NamespaceAlreadyRegistered, + + 
#[error("Internal Error: {0}")] + Other(anyhow::Error), +} + +impl From for CoordinatorError { + fn from(err: anyhow::Error) -> Self { + Self::Other(err) + } +} + +impl From for CoordinatorError { + fn from(err: tokio::task::JoinError) -> Self { + Self::Other(err.into()) + } +} + +impl From for CoordinatorError { + fn from(err: std::io::Error) -> Self { + Self::Other(err.into()) + } +} + +pub type CoordinatorResult = std::result::Result; + +/// Handle returned when a namespace is registered with the coordinator. +/// +/// Dropping this handle automatically unregisters the namespace. +/// This provides RAII-based cleanup - when the publisher disconnects +/// or the namespace is no longer served, cleanup happens automatically. +pub struct NamespaceRegistration { + _inner: Box, + _metadata: Option>, +} + +impl NamespaceRegistration { + /// Create a new registration handle wrapping any Send + Sync type. + /// + /// The wrapped value's `Drop` implementation will be called when + /// this registration is dropped. + pub fn new(inner: T) -> Self { + Self { + _inner: Box::new(inner), + _metadata: None, + } + } + + /// Add metadata as list of key value pair of string: string + pub fn with_metadata(mut self, metadata: Vec<(String, String)>) -> Self { + self._metadata = Some(metadata); + self + } +} + +/// Result of a namespace lookup. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NamespaceOrigin { + namespace: TrackNamespace, + url: Url, + metadata: Option>, +} + +impl NamespaceOrigin { + /// Create a new NamespaceOrigin. + pub fn new(namespace: TrackNamespace, url: Url) -> Self { + Self { + namespace, + url, + metadata: None, + } + } + + pub fn with_metadata(mut self, values: (String, String)) -> Self { + if let Some(metadata) = &mut self.metadata { + metadata.push(values); + } else { + self.metadata = Some(vec![values]); + } + self + } + + /// Get the namespace. 
+ pub fn namespace(&self) -> &TrackNamespace { + &self.namespace + } + + /// Get the URL of the relay serving this namespace. + pub fn url(&self) -> Url { + self.url.clone() + } + + /// Get the metadata associated with this namespace. + pub fn metadata(&self) -> Option> { + self.metadata.clone() + } +} + +/// Coordinator handles namespace registration/discovery across relays. +/// +/// Implementations are responsible for: +/// - Tracking which namespaces are served locally +/// - Caching remote namespace lookups +/// - Communicating with external registries (HTTP API, Redis, etc.) +/// - Periodic refresh/heartbeat of registrations +/// - Cleanup when registrations are dropped +/// +/// # Thread Safety +/// +/// All methods take `&self` and implementations must be thread-safe. +/// Multiple tasks will call these methods concurrently. +#[async_trait] +pub trait Coordinator: Send + Sync { + /// Register a namespace as locally available on this relay. + /// + /// Called when a publisher sends PUBLISH_NAMESPACE. + /// The coordinator should: + /// 1. Record the namespace as locally available + /// 2. Advertise to external registry if configured + /// 3. Start any refresh/heartbeat tasks + /// 4. Return a handle that unregisters on drop + /// + /// # Arguments + /// + /// * `namespace` - The namespace being registered + /// + /// # Returns + /// + /// A `NamespaceRegistration` handle. The namespace remains registered + /// as long as this handle is held. Dropping it unregisters the namespace. + async fn register_namespace( + &self, + namespace: &TrackNamespace, + ) -> CoordinatorResult; + + /// Unregister a namespace. + /// + /// Called when a publisher sends PUBLISH_NAMESPACE_DONE. + /// This is an explicit unregistration - the registration handle may still exist + /// but the namespace should be removed from the registry. 
+ /// + /// # Arguments + /// + /// * `namespace` - The namespace to unregister + async fn unregister_namespace(&self, namespace: &TrackNamespace) -> CoordinatorResult<()>; + + /// Lookup where a namespace is served from. + /// + /// Called when a subscriber requests a namespace. + /// The coordinator should check in order: + /// 1. Local registrations + /// 2. Cached remote lookups (if not expired) + /// 3. External registry (cache and return result) + /// + /// # Arguments + /// + /// * `namespace` - The namespace to look up + /// + /// # Returns + /// + /// - `Ok((origin, client))` - The `NamespaceOrigin` (matched namespace and URL of the relay serving it) and an optional client if available + /// - `Err` - Namespace not found anywhere + async fn lookup( + &self, + namespace: &TrackNamespace, + ) -> CoordinatorResult<(NamespaceOrigin, Option)>; + + /// Graceful shutdown of the coordinator. + /// + /// Called when the relay is shutting down. Implementations should: + /// - Unregister all local namespaces and tracks + /// - Cancel refresh tasks + /// - Close connections to external registries + async fn shutdown(&self) -> CoordinatorResult<()> { + Ok(()) + } +} diff --git a/moq-relay-ietf/src/lib.rs b/moq-relay-ietf/src/lib.rs new file mode 100644 index 0000000..aac3932 --- /dev/null +++ b/moq-relay-ietf/src/lib.rs @@ -0,0 +1,49 @@ +//! MoQ Relay library for building Media over QUIC relay servers. +//! +//! This crate provides the core relay functionality that can be embedded +//! into other applications. The relay handles: +//! +//! - Accepting QUIC connections from publishers and subscribers +//! - Routing media between local and remote endpoints +//! - Coordinating namespace/track registration across relay clusters +//! +//! # Example +//! +//! ```rust,ignore +//! use std::sync::Arc; +//! use moq_relay_ietf::{Relay, RelayConfig}; +//! +//! // Create a coordinator: any `Coordinator` implementation (the relay binary uses a file-based one) +//!
let coordinator = Arc::new(my_coordinator); +//! +//! // Configure and create the relay +//! let relay = Relay::new(RelayConfig { +//! bind: Some("[::]:443".parse().unwrap()), +//! tls: tls_config, +//! coordinator, +//! // ... other options +//! })?; +//! +//! // Run the relay +//! relay.run().await?; +//! ``` + +mod api; +mod consumer; +mod coordinator; +mod local; +mod producer; +mod relay; +mod remote; +mod session; +mod web; + +pub use api::*; +pub use consumer::*; +pub use coordinator::*; +pub use local::*; +pub use producer::*; +pub use relay::*; +pub use remote::*; +pub use session::*; +pub use web::*; diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 6f355a0..406e665 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -46,9 +46,9 @@ impl Locals { Ok(registration) } - /// Lookup local tracks by namespace using hierarchical prefix matching. + /// Retrieve local tracks by namespace using hierarchical prefix matching. /// Returns the TracksReader for the longest matching namespace prefix.
- pub fn route(&self, namespace: &TrackNamespace) -> Option { + pub fn retrieve(&self, namespace: &TrackNamespace) -> Option { let lookup = self.lookup.lock().unwrap(); // Find the longest matching prefix diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 3de13be..4bd8e17 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -4,20 +4,20 @@ use moq_transport::{ session::{Publisher, SessionError, Subscribed, TrackStatusRequested}, }; -use crate::{Locals, RemotesConsumer}; +use crate::{Locals, RemoteManager}; /// Producer of tracks to a remote Subscriber #[derive(Clone)] pub struct Producer { - remote_publisher: Publisher, + publisher: Publisher, locals: Locals, - remotes: Option, + remotes: RemoteManager, } impl Producer { - pub fn new(remote: Publisher, locals: Locals, remotes: Option) -> Self { + pub fn new(publisher: Publisher, locals: Locals, remotes: RemoteManager) -> Self { Self { - remote_publisher: remote, + publisher, locals, remotes, } @@ -25,7 +25,7 @@ impl Producer { /// Announce new tracks to the remote server. pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - self.remote_publisher.announce(tracks).await + self.publisher.announce(tracks).await } /// Run the producer to serve subscribe requests. @@ -35,12 +35,12 @@ impl Producer { FuturesUnordered::new(); loop { - let mut remote_publisher_subscribed = self.remote_publisher.clone(); - let mut remote_publisher_track_status = self.remote_publisher.clone(); + let mut publisher_subscribed = self.publisher.clone(); + let mut publisher_track_status = self.publisher.clone(); tokio::select! 
{ // Handle a new subscribe request - Some(subscribed) = remote_publisher_subscribed.subscribed() => { + Some(subscribed) = publisher_subscribed.subscribed() => { let this = self.clone(); // Spawn a new task to handle the subscribe @@ -50,12 +50,12 @@ impl Producer { // Serve the subscribe request if let Err(err) = this.serve_subscribe(subscribed).await { - log::warn!("failed serving subscribe: {:?}, error: {}", info, err) + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); } }.boxed()) }, // Handle a new track_status request - Some(track_status_requested) = remote_publisher_track_status.track_status_requested() => { + Some(track_status_requested) = publisher_track_status.track_status_requested() => { let this = self.clone(); // Spawn a new task to handle the track_status request @@ -77,44 +77,35 @@ impl Producer { /// Serve a subscribe request. async fn serve_subscribe(self, subscribed: Subscribed) -> Result<(), anyhow::Error> { + let namespace = subscribed.track_namespace.clone(); + let track_name = subscribed.track_name.clone(); + // Check local tracks first, and serve from local if possible - if let Some(mut local) = self.locals.route(&subscribed.track_namespace) { + if let Some(mut local) = self.locals.retrieve(&namespace) { // Pass the full requested namespace, not the announced prefix - if let Some(track) = - local.subscribe(subscribed.track_namespace.clone(), &subscribed.track_name) - { + if let Some(track) = local.subscribe(namespace.clone(), &track_name) { log::info!("serving subscribe from local: {:?}", track.info); return Ok(subscribed.serve(track).await?); } } // Check remote tracks second, and serve from remote if possible - if let Some(remotes) = &self.remotes { - // Try to route to a remote for this namespace - if let Some(remote) = remotes.route(&subscribed.track_namespace).await? { - if let Some(track) = remote.subscribe( - subscribed.track_namespace.clone(), - subscribed.track_name.clone(), - )? 
{ - log::info!( - "serving subscribe from remote: {:?} {:?}", - remote.info, - track.info - ); - - // NOTE: Depends on drop(track) being called afterwards - return Ok(subscribed.serve(track.reader).await?); - } - } + if let Some(track) = self + .remotes + .subscribe(namespace.clone(), track_name.clone()) + .await? + { + log::info!("serving subscribe from remote: {:?}", track.info); + return Ok(subscribed.serve(track).await?); } - let namespace = subscribed.track_namespace.clone(); - let name = subscribed.track_name.clone(); - Err(ServeError::not_found_ctx(format!( + // Track not found - close the subscription with not found error + let err = ServeError::not_found_ctx(format!( "track '{}/{}' not found in local or remote tracks", - namespace, name - )) - .into()) + namespace, track_name + )); + subscribed.close(err.clone())?; + Err(err.into()) } /// Serve a track_status request. @@ -125,7 +116,7 @@ impl Producer { // Check local tracks first, and serve from local if possible if let Some(mut local_tracks) = self .locals - .route(&track_status_requested.request_msg.track_namespace) + .retrieve(&track_status_requested.request_msg.track_namespace) { if let Some(track) = local_tracks.get_track_reader( &track_status_requested.request_msg.track_namespace, diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 41a6823..662e5a1 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -1,17 +1,32 @@ -use std::{net, path::PathBuf}; +use std::{future::Future, net, path::PathBuf, pin::Pin, sync::Arc}; use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; -use moq_native_ietf::quic; +use moq_native_ietf::quic::{self, Endpoint}; use url::Url; -use crate::{Api, Consumer, Locals, Producer, Remotes, RemotesConsumer, RemotesProducer, Session}; +use crate::{Consumer, Coordinator, Locals, Producer, RemoteManager, Session}; + +// A type alias for boxed future +type ServerFuture = Pin< + Box< + dyn Future< + Output = 
( + anyhow::Result<(web_transport::Session, String)>, + quic::Server, + ), + >, + >, +>; /// Configuration for the relay. pub struct RelayConfig { /// Listen on this address - pub bind: net::SocketAddr, + pub bind: Option, + + /// Optional list of endpoints if provided, we won't use bind + pub endpoints: Vec, /// The TLS configuration. pub tls: moq_native_ietf::tls::Config, @@ -25,32 +40,44 @@ pub struct RelayConfig { /// Forward all announcements to the (optional) URL. pub announce: Option, - /// Connect to the HTTP moq-api at this URL. - pub api: Option, - /// Our hostname which we advertise to other origins. /// We use QUIC, so the certificate must be valid for this address. pub node: Option, + + /// The coordinator for namespace/track registration and discovery. + pub coordinator: Arc, } /// MoQ Relay server. pub struct Relay { - quic: quic::Endpoint, + quic_endpoints: Vec, announce_url: Option, mlog_dir: Option, locals: Locals, - api: Option, - remotes: Option<(RemotesProducer, RemotesConsumer)>, + remotes: RemoteManager, + coordinator: Arc, } impl Relay { pub fn new(config: RelayConfig) -> anyhow::Result { - // Create a QUIC endpoint that can be used for both clients and servers. 
- let quic = quic::Endpoint::new(quic::Config { - bind: config.bind, - qlog_dir: config.qlog_dir, - tls: config.tls, - })?; + if config.bind.is_some() && !config.endpoints.is_empty() { + anyhow::bail!("cannot specify both bind and endpoints"); + } + + let endpoints = if config.bind.is_some() { + let endpoint = quic::Endpoint::new(quic::Config::new( + config.bind.unwrap(), + config.qlog_dir.clone(), + config.tls.clone(), + ))?; + vec![endpoint] + } else { + config.endpoints + }; + + if endpoints.is_empty() { + anyhow::bail!("no endpoints available to start the server"); + } // Validate mlog directory if provided if let Some(mlog_dir) = &config.mlog_dir { @@ -63,32 +90,24 @@ impl Relay { log::info!("mlog output enabled: {}", mlog_dir.display()); } - // Create an API client if we have the necessary configuration - let api = if let (Some(url), Some(node)) = (config.api, config.node) { - log::info!("using moq-api: url={} node={}", url, node); - Some(Api::new(url, node)) - } else { - None - }; - let locals = Locals::new(); - // Create remotes if we have an API client - let remotes = api.clone().map(|api| { - Remotes { - api, - quic: quic.client.clone(), - } - .produce() - }); + // FIXME(itzmanish): have a generic filter to find endpoints for forward, remote etc. 
+ let remote_clients = endpoints + .iter() + .map(|endpoint| endpoint.client.clone()) + .collect::>(); + + // Create remote manager - uses coordinator for namespace lookups + let remotes = RemoteManager::new(config.coordinator.clone(), remote_clients)?; Ok(Self { - quic, + quic_endpoints: endpoints, announce_url: config.announce, mlog_dir: config.mlog_dir, - api, locals, remotes, + coordinator: config.coordinator, }) } @@ -96,21 +115,16 @@ impl Relay { pub async fn run(self) -> anyhow::Result<()> { let mut tasks = FuturesUnordered::new(); - // Start the remotes producer task, if any - let remotes = self.remotes.map(|(producer, consumer)| { - tasks.push(producer.run().boxed()); - consumer - }); + let remotes = self.remotes.clone(); // Start the forwarder, if any let forward_producer = if let Some(url) = &self.announce_url { log::info!("forwarding announces to {}", url); // Establish a QUIC connection to the forward URL - let (session, _quic_client_initial_cid) = self - .quic + let (session, _quic_client_initial_cid) = self.quic_endpoints[0] .client - .connect(url) + .connect(url, None) .await .context("failed to establish forward connection")?; @@ -121,6 +135,7 @@ impl Relay { .context("failed to establish forward session")?; // Create a normal looking session, except we never forward or register announces. 
+ let coordinator = self.coordinator.clone(); let session = Session { session, producer: Some(Producer::new( @@ -128,7 +143,12 @@ impl Relay { self.locals.clone(), remotes.clone(), )), - consumer: Some(Consumer::new(subscriber, self.locals.clone(), None, None)), + consumer: Some(Consumer::new( + subscriber, + self.locals.clone(), + coordinator, + None, + )), }; let forward_producer = session.producer.clone(); @@ -140,15 +160,46 @@ impl Relay { None }; - // Start the QUIC server loop - let mut server = self.quic.server.context("missing TLS certificate")?; - log::info!("listening on {}", server.local_addr()?); + let servers: Vec = self + .quic_endpoints + .into_iter() + .map(|endpoint| { + endpoint + .server + .context("missing TLS certificate for server") + }) + .collect::>()?; + + // This will hold the futures for all our listening servers. + let mut accepts: FuturesUnordered = FuturesUnordered::new(); + for mut server in servers { + log::info!("listening on {}", server.local_addr()?); + + // Create a future, box it, and push it to the collection. + accepts.push( + async move { + let conn = server.accept().await.context("accept failed"); + (conn, server) + } + .boxed(), + ); + } loop { tokio::select! { - // Accept a new QUIC connection - res = server.accept() => { - let (conn, connection_id) = res.context("failed to accept QUIC connection")?; + // This branch polls all the `accept` futures concurrently. + Some((conn_result, mut server)) = accepts.next() => { + // An accept operation has completed. + // First, immediately queue up the next accept() call for this server. 
+ accepts.push( + async move { + let conn = server.accept().await.context("accept failed"); + (conn, server) + } + .boxed(), + ); + + let (conn, connection_id) = conn_result.context("failed to accept QUIC connection")?; // Construct mlog path from connection ID if mlog directory is configured let mlog_path = self.mlog_dir.as_ref() @@ -157,11 +208,10 @@ impl Relay { let locals = self.locals.clone(); let remotes = remotes.clone(); let forward = forward_producer.clone(); - let api = self.api.clone(); + let coordinator = self.coordinator.clone(); // Spawn a new task to handle the connection tasks.push(async move { - // Create the MoQ session over the connection (setup handshake etc) let (session, publisher, subscriber) = match moq_transport::session::Session::accept(conn, mlog_path).await { Ok(session) => session, @@ -172,10 +222,11 @@ impl Relay { }; // Create our MoQ relay session + let moq_session = session; let session = Session { - session, + session: moq_session, producer: publisher.map(|publisher| Producer::new(publisher, locals.clone(), remotes)), - consumer: subscriber.map(|subscriber| Consumer::new(subscriber, locals, api, forward)), + consumer: subscriber.map(|subscriber| Consumer::new(subscriber, locals, coordinator, forward)), }; if let Err(err) = session.run().await { diff --git a/moq-relay-ietf/src/remote.rs b/moq-relay-ietf/src/remote.rs index 1b412d0..8d2c645 100644 --- a/moq-relay-ietf/src/remote.rs +++ b/moq-relay-ietf/src/remote.rs @@ -1,412 +1,203 @@ use std::collections::HashMap; - -use std::collections::VecDeque; -use std::fmt; -use std::ops; use std::sync::Arc; -use std::sync::Weak; -use futures::stream::FuturesUnordered; -use futures::FutureExt; -use futures::StreamExt; use moq_native_ietf::quic; use moq_transport::coding::TrackNamespace; -use moq_transport::serve::{Track, TrackReader, TrackWriter}; -use moq_transport::watch::State; +use moq_transport::serve::{Track, TrackReader}; +use tokio::sync::Mutex; use url::Url; -use crate::Api; - -/// 
Information about remote origins. -pub struct Remotes { - /// The client we use to fetch/store origin information. - pub api: Api, - - // A QUIC endpoint we'll use to fetch from other origins. - pub quic: quic::Client, -} - -impl Remotes { - pub fn produce(self) -> (RemotesProducer, RemotesConsumer) { - let (send, recv) = State::default().split(); - let info = Arc::new(self); - - let producer = RemotesProducer::new(info.clone(), send); - let consumer = RemotesConsumer::new(info, recv); - - (producer, consumer) - } -} - -#[derive(Default)] -struct RemotesState { - lookup: HashMap, - requested: VecDeque, -} +use crate::Coordinator; -// Clone for convenience, but there should only be one instance of this +/// Manages connections to remote relays. +/// +/// When a subscription request comes in for a namespace that isn't local, +/// RemoteManager uses the coordinator to find which remote relay serves it, +/// establishes a connection if needed, and subscribes to the track. #[derive(Clone)] -pub struct RemotesProducer { - info: Arc, - state: State, -} - -impl RemotesProducer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } - } - - /// Block until the next remote requested by a consumer. - async fn next(&mut self) -> Option { - loop { - { - let state = self.state.lock(); - if !state.requested.is_empty() { - return state.into_mut()?.requested.pop_front(); - } - - state.modified()? - } - .await; - } - } - - /// Run the remotes producer to serve remote requests. - pub async fn run(mut self) -> anyhow::Result<()> { - let mut tasks = FuturesUnordered::new(); - - loop { - tokio::select! 
{ - Some(mut remote) = self.next() => { - let url = remote.url.clone(); - - // Spawn a task to serve the remote - tasks.push(async move { - let info = remote.info.clone(); - - log::warn!("serving remote: {:?}", info); - - // Run the remote producer - if let Err(err) = remote.run().await { - log::warn!("failed serving remote: {:?}, error: {}", info, err); - } - - url - }); - } - - // Handle finished remote producers - res = tasks.next(), if !tasks.is_empty() => { - let url = res.unwrap(); - - if let Some(mut state) = self.state.lock_mut() { - state.lookup.remove(&url); - } - }, - else => return Ok(()), - } - } +pub struct RemoteManager { + coordinator: Arc, + clients: Vec, + remotes: Arc>>, +} + +impl RemoteManager { + /// Create a new RemoteManager. + pub fn new( + coordinator: Arc, + clients: Vec, + ) -> anyhow::Result { + Ok(Self { + coordinator, + clients, + remotes: Arc::new(Mutex::new(HashMap::new())), + }) } -} -impl ops::Deref for RemotesProducer { - type Target = Remotes; + /// Subscribe to a track from a remote relay. + /// + /// This will: + /// 1. Use the coordinator to lookup which relay serves the namespace + /// 2. Connect to that relay if not already connected + /// 3. Subscribe to the specific track + /// + /// Returns None if the namespace isn't found in any remote relay. 
+ pub async fn subscribe( + &self, + namespace: TrackNamespace, + track_name: String, + ) -> anyhow::Result> { + // Ask coordinator where this namespace lives + let (origin, client) = match self.coordinator.lookup(&namespace).await { + Ok((origin, client)) => (origin, client), + Err(_) => return Ok(None), // Namespace not found anywhere + }; - fn deref(&self) -> &Self::Target { - &self.info - } -} + let url = origin.url(); -#[derive(Clone)] -pub struct RemotesConsumer { - pub info: Arc, - state: State, -} + // Get or create a connection to the remote relay + let remote = self.get_or_connect(&url, client.as_ref()).await?; -impl RemotesConsumer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } + // Subscribe to the track on the remote + remote.subscribe(namespace, track_name).await } - /// Route to a remote origin based on the namespace. - pub async fn route( + /// Get an existing remote connection or create a new one. + async fn get_or_connect( &self, - namespace: &TrackNamespace, - ) -> anyhow::Result> { - // Always fetch the origin instead of using the (potentially invalid) cache. - let origin = match self.api.get_origin(&namespace.to_utf8_path()).await? 
{ - None => return Ok(None), - Some(origin) => origin, - }; - - // Check if we already have a remote for this origin - let state = self.state.lock(); - if let Some(remote) = state.lookup.get(&origin.url).cloned() { - return Ok(Some(remote)); + url: &Url, + client: Option<&quic::Client>, + ) -> anyhow::Result { + let mut remotes = self.remotes.lock().await; + + // Check if we already have a connection + if let Some(remote) = remotes.get(url) { + if remote.is_connected() { + return Ok(remote.clone()); + } + // Connection is dead, remove it + remotes.remove(url); } - // Create a new remote for this origin - let mut state = match state.into_mut() { - Some(state) => state, - None => return Ok(None), - }; - - let remote = Remote { - url: origin.url.clone(), - remotes: self.info.clone(), - }; + let client = client.unwrap_or(&self.clients[0]); - // Produce the remote - let (writer, reader) = remote.produce(); - state.requested.push_back(writer); + // Create a new connection with its own QUIC client + log::info!("connecting to remote relay: {}", url); + let remote = Remote::connect(url.clone(), client).await?; - // Insert the remote into our Map - state.lookup.insert(origin.url, reader.clone()); + remotes.insert(url.clone(), remote.clone()); - Ok(Some(reader)) + Ok(remote) } -} -impl ops::Deref for RemotesConsumer { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.info + /// Remove a remote connection (called when connection fails). + pub async fn remove(&self, url: &Url) { + let mut remotes = self.remotes.lock().await; + remotes.remove(url); } } +/// A connection to a single remote relay with its own QUIC client. 
+#[derive(Clone)] pub struct Remote { - pub remotes: Arc, - pub url: Url, -} - -impl fmt::Debug for Remote { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Remote") - .field("url", &self.url.to_string()) - .finish() - } -} - -impl ops::Deref for Remote { - type Target = Remotes; - - fn deref(&self) -> &Self::Target { - &self.remotes - } + url: Url, + subscriber: moq_transport::session::Subscriber, + /// Track subscriptions - maps (namespace, track_name) to track reader + tracks: Arc>>, } impl Remote { - /// Create a new broadcast. - pub fn produce(self) -> (RemoteProducer, RemoteConsumer) { - let (send, recv) = State::default().split(); - let info = Arc::new(self); - - let consumer = RemoteConsumer::new(info.clone(), recv); - let producer = RemoteProducer::new(info, send); - - (producer, consumer) - } -} - -#[derive(Default)] -struct RemoteState { - tracks: HashMap<(TrackNamespace, String), RemoteTrackWeak>, - requested: VecDeque, -} - -pub struct RemoteProducer { - pub info: Arc, - state: State, -} - -impl RemoteProducer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } - } - - pub async fn run(&mut self) -> anyhow::Result<()> { - // TODO reuse QUIC and MoQ sessions - let (session, _quic_client_initial_cid) = self.quic.connect(&self.url).await?; + /// Connect to a remote relay with a dedicated QUIC client. + async fn connect(url: Url, client: &quic::Client) -> anyhow::Result { + // Connect to the remote relay (DNS resolution happens inside connect) + let (session, _cid) = client.connect(&url, None).await?; let (session, subscriber) = moq_transport::session::Subscriber::connect(session).await?; - // Run the session - let mut session = session.run().boxed(); - let mut tasks = FuturesUnordered::new(); - - let mut done = None; - - // Serve requested tracks - loop { - tokio::select! 
{ - track = self.next(), if done.is_none() => { - let track = match track { - Ok(Some(track)) => track, - Ok(None) => { done = Some(Ok(())); continue }, - Err(err) => { done = Some(Err(err)); continue }, - }; - - let info = track.info.clone(); - let mut subscriber = subscriber.clone(); - - tasks.push(async move { - if let Err(err) = subscriber.subscribe(track).await { - log::warn!("failed serving track: {:?}, error: {}", info, err); - } - }); - } - _ = tasks.next(), if !tasks.is_empty() => {}, - - // Keep running the session - res = &mut session, if !tasks.is_empty() || done.is_none() => return Ok(res?), - - else => return done.unwrap(), + // Spawn a task to run the session + let session_url = url.clone(); + tokio::spawn(async move { + if let Err(err) = session.run().await { + log::warn!("remote session closed: {} - {}", session_url, err); } - } - } - - /// Block until the next track requested by a consumer. - async fn next(&self) -> anyhow::Result> { - loop { - let notify = { - let state = self.state.lock(); - - // Check if we have any requested tracks - if !state.requested.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.requested.pop_front())); - } - - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - }; - - notify.await - } - } -} - -impl ops::Deref for RemoteProducer { - type Target = Remote; + }); - fn deref(&self) -> &Self::Target { - &self.info + Ok(Self { + url, + subscriber, + tracks: Arc::new(Mutex::new(HashMap::new())), + }) } -} -#[derive(Clone)] -pub struct RemoteConsumer { - pub info: Arc, - state: State, -} - -impl RemoteConsumer { - fn new(info: Arc, state: State) -> Self { - Self { info, state } + /// Check if the connection is still alive. + /// Note: This is a simple heuristic - we assume connected until proven otherwise. + fn is_connected(&self) -> bool { + // We don't have a direct way to check if the subscriber is closed, + // so we assume it's connected. 
Dead connections will be cleaned up + // when subscribe operations fail. + true } - /// Request a track from the broadcast. - pub fn subscribe( + /// Subscribe to a track on this remote relay. + pub async fn subscribe( &self, namespace: TrackNamespace, - name: String, - ) -> anyhow::Result> { - let key = (namespace.clone(), name.clone()); - let state = self.state.lock(); - if let Some(track) = state.tracks.get(&key) { - if let Some(track) = track.upgrade() { - return Ok(Some(track)); + track_name: String, + ) -> anyhow::Result> { + let key = (namespace.clone(), track_name.clone()); + + // Check if we already have this track + { + let tracks = self.tracks.lock().await; + if let Some(reader) = tracks.get(&key) { + return Ok(Some(reader.clone())); } } - let mut state = match state.into_mut() { - Some(state) => state, - None => return Ok(None), - }; - - let (writer, reader) = Track::new(namespace, name).produce(); - let reader = RemoteTrackReader::new(reader, self.state.clone()); - - // Insert the track into our Map so we deduplicate future requests. 
- state.tracks.insert(key, reader.downgrade()); - state.requested.push_back(writer); - - Ok(Some(reader)) - } -} - -impl ops::Deref for RemoteConsumer { - type Target = Remote; - - fn deref(&self) -> &Self::Target { - &self.info - } -} - -#[derive(Clone)] -pub struct RemoteTrackReader { - pub reader: TrackReader, - drop: Arc, -} + // Create a new track and subscribe + let (writer, reader) = Track::new(namespace.clone(), track_name.clone()).produce(); + + // Subscribe to the track on the remote + let mut subscriber = self.subscriber.clone(); + let track_key = key.clone(); + let tracks = self.tracks.clone(); + let url = self.url.clone(); + + tokio::spawn(async move { + log::info!( + "subscribing to remote track: {} - {}/{}", + url, + track_key.0, + track_key.1 + ); + + if let Err(err) = subscriber.subscribe(writer).await { + log::warn!( + "failed subscribing to remote track: {} - {}/{} - {}", + url, + track_key.0, + track_key.1, + err + ); + } -impl RemoteTrackReader { - fn new(reader: TrackReader, parent: State) -> Self { - let drop = Arc::new(RemoteTrackDrop { - parent, - key: (reader.namespace.clone(), reader.name.clone()), + // Remove track from map when subscription ends + tracks.lock().await.remove(&track_key); }); - Self { reader, drop } - } - - fn downgrade(&self) -> RemoteTrackWeak { - RemoteTrackWeak { - reader: self.reader.clone(), - drop: Arc::downgrade(&self.drop), + // Store the reader for deduplication + { + let mut tracks = self.tracks.lock().await; + tracks.insert(key, reader.clone()); } - } -} -impl ops::Deref for RemoteTrackReader { - type Target = TrackReader; - - fn deref(&self) -> &Self::Target { - &self.reader - } -} - -impl ops::DerefMut for RemoteTrackReader { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.reader - } -} - -struct RemoteTrackWeak { - reader: TrackReader, - drop: Weak, -} - -impl RemoteTrackWeak { - fn upgrade(&self) -> Option { - Some(RemoteTrackReader { - reader: self.reader.clone(), - drop: 
self.drop.upgrade()?, - }) + Ok(Some(reader)) } } -struct RemoteTrackDrop { - parent: State, - key: (TrackNamespace, String), -} - -impl Drop for RemoteTrackDrop { - fn drop(&mut self) { - if let Some(mut parent) = self.parent.lock_mut() { - parent.tracks.remove(&self.key); - } +impl std::fmt::Debug for Remote { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Remote") + .field("url", &self.url.to_string()) + .finish() } } diff --git a/moq-sub/src/main.rs b/moq-sub/src/main.rs index e38154f..2663833 100644 --- a/moq-sub/src/main.rs +++ b/moq-sub/src/main.rs @@ -22,13 +22,9 @@ async fn main() -> anyhow::Result<()> { let config = Config::parse(); let tls = config.tls.load()?; - let quic = quic::Endpoint::new(quic::Config { - bind: config.bind, - qlog_dir: None, - tls, - })?; + let quic = quic::Endpoint::new(quic::Config::new(config.bind, None, tls))?; - let (session, connection_id) = quic.client.connect(&config.url).await?; + let (session, connection_id) = quic.client.connect(&config.url, None).await?; log::info!( "connected with CID: {} (use this to look up qlog/mlog on server)", diff --git a/moq-transport/src/coding/track_namespace.rs b/moq-transport/src/coding/track_namespace.rs index 338fa5a..4179038 100644 --- a/moq-transport/src/coding/track_namespace.rs +++ b/moq-transport/src/coding/track_namespace.rs @@ -1,6 +1,18 @@ use super::{Decode, DecodeError, Encode, EncodeError, TupleField}; use core::hash::{Hash, Hasher}; +use std::convert::TryFrom; use std::fmt; +use thiserror::Error; + +/// Error type for TrackNamespace conversion failures +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum TrackNamespaceError { + #[error("too many fields: {0} exceeds maximum of {1}")] + TooManyFields(usize, usize), + + #[error("field too large: {0} bytes exceeds maximum of {1}")] + FieldTooLarge(usize, usize), +} /// TrackNamespace #[derive(Clone, Default, Eq, PartialEq)] @@ -92,11 +104,69 @@ impl fmt::Display for TrackNamespace { } } 
+impl TryFrom> for TrackNamespace { + type Error = TrackNamespaceError; + + fn try_from(fields: Vec) -> Result { + if fields.len() > Self::MAX_FIELDS { + return Err(TrackNamespaceError::TooManyFields( + fields.len(), + Self::MAX_FIELDS, + )); + } + for field in &fields { + if field.value.len() > TupleField::MAX_VALUE_SIZE { + return Err(TrackNamespaceError::FieldTooLarge( + field.value.len(), + TupleField::MAX_VALUE_SIZE, + )); + } + } + Ok(Self { fields }) + } +} + +impl TryFrom<&str> for TrackNamespace { + type Error = TrackNamespaceError; + + fn try_from(path: &str) -> Result { + let fields: Vec = path.split('/').map(TupleField::from_utf8).collect(); + Self::try_from(fields) + } +} + +impl TryFrom for TrackNamespace { + type Error = TrackNamespaceError; + + fn try_from(path: String) -> Result { + Self::try_from(path.as_str()) + } +} + +impl TryFrom> for TrackNamespace { + type Error = TrackNamespaceError; + + fn try_from(parts: Vec<&str>) -> Result { + let fields: Vec = parts.into_iter().map(TupleField::from_utf8).collect(); + Self::try_from(fields) + } +} + +impl TryFrom> for TrackNamespace { + type Error = TrackNamespaceError; + + fn try_from(parts: Vec) -> Result { + let fields: Vec = parts.iter().map(|s| TupleField::from_utf8(s)).collect(); + Self::try_from(fields) + } +} + #[cfg(test)] mod tests { use super::*; use bytes::Bytes; use bytes::BytesMut; + use std::convert::TryInto; #[test] fn encode_decode() { @@ -165,4 +235,69 @@ mod tests { DecodeError::FieldBoundsExceeded(_) )); } + + #[test] + fn try_from_str() { + let ns: TrackNamespace = "test/path/to/resource".try_into().unwrap(); + assert_eq!(ns.fields.len(), 4); + assert_eq!(ns.to_utf8_path(), "/test/path/to/resource"); + } + + #[test] + fn try_from_string() { + let path = String::from("test/path"); + let ns: TrackNamespace = path.try_into().unwrap(); + assert_eq!(ns.fields.len(), 2); + assert_eq!(ns.to_utf8_path(), "/test/path"); + } + + #[test] + fn try_from_vec_str() { + let parts = vec!["test", 
"path", "to", "resource"]; + let ns: TrackNamespace = parts.try_into().unwrap(); + assert_eq!(ns.fields.len(), 4); + assert_eq!(ns.to_utf8_path(), "/test/path/to/resource"); + } + + #[test] + fn try_from_vec_string() { + let parts = vec![String::from("test"), String::from("path")]; + let ns: TrackNamespace = parts.try_into().unwrap(); + assert_eq!(ns.fields.len(), 2); + assert_eq!(ns.to_utf8_path(), "/test/path"); + } + + #[test] + fn try_from_vec_tuple_field() { + let fields = vec![TupleField::from_utf8("test"), TupleField::from_utf8("path")]; + let ns: TrackNamespace = fields.try_into().unwrap(); + assert_eq!(ns.fields.len(), 2); + assert_eq!(ns.to_utf8_path(), "/test/path"); + } + + #[test] + fn try_from_too_many_fields() { + let mut fields = Vec::new(); + for i in 0..TrackNamespace::MAX_FIELDS + 1 { + fields.push(TupleField::from_utf8(&format!("field{}", i))); + } + let result: Result = fields.try_into(); + assert!(matches!( + result.unwrap_err(), + TrackNamespaceError::TooManyFields(33, 32) + )); + } + + #[test] + fn try_from_field_too_large() { + let large_value = "x".repeat(TupleField::MAX_VALUE_SIZE + 1); + let fields = vec![TupleField { + value: large_value.into_bytes(), + }]; + let result: Result = fields.try_into(); + assert!(matches!( + result.unwrap_err(), + TrackNamespaceError::FieldTooLarge(4097, 4096) + )); + } }