From b3eca9ce8321834ae09372145e9e228a8eeaeea5 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Wed, 31 Jul 2024 17:12:45 +0200 Subject: [PATCH] feat: add support for dynamic rate limit configurations with hot reload - Introduce dynamic rate limiting configurations - Incorporate hot reload functionality to update rate limits without a server restart - Allow configuration of hot reload of rate limits in `gateway` server setup --- Cargo.lock | 1 + .../federated-dev/src/dev/gateway_nanny.rs | 7 +- .../src/federation/builder/test_runtime.rs | 4 +- .../src/rate_limiting/in_memory/key_based.rs | 83 ++++++++---- .../runtime-local/src/rate_limiting/redis.rs | 27 ++-- engine/crates/runtime/src/rate_limiting.rs | 4 +- gateway/crates/federated-server/Cargo.toml | 2 + gateway/crates/federated-server/src/config.rs | 27 +++- .../federated-server/src/config/hot_reload.rs | 124 ++++++++++++++++++ .../federated-server/src/config/rate_limit.rs | 2 +- gateway/crates/federated-server/src/error.rs | 12 +- gateway/crates/federated-server/src/lib.rs | 2 +- gateway/crates/federated-server/src/server.rs | 32 ++++- .../federated-server/src/server/gateway.rs | 39 +++++- gateway/crates/gateway-binary/src/args.rs | 6 +- .../crates/gateway-binary/src/args/lambda.rs | 14 +- gateway/crates/gateway-binary/src/args/std.rs | 17 ++- gateway/crates/gateway-binary/src/main.rs | 15 ++- 18 files changed, 346 insertions(+), 72 deletions(-) create mode 100644 gateway/crates/federated-server/src/config/hot_reload.rs diff --git a/Cargo.lock b/Cargo.lock index 40eed67969..7c981e7ce6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2873,6 +2873,7 @@ dependencies = [ "indoc", "insta", "lambda_http", + "notify", "parser-sdl", "regex", "reqwest", diff --git a/cli/crates/federated-dev/src/dev/gateway_nanny.rs b/cli/crates/federated-dev/src/dev/gateway_nanny.rs index 3aca3aaf29..1f3e3ad0a4 100644 --- a/cli/crates/federated-dev/src/dev/gateway_nanny.rs +++ b/cli/crates/federated-dev/src/dev/gateway_nanny.rs @@ -9,6 +9,7 @@ use futures_concurrency::stream::Merge; use futures_util::{future::BoxFuture, stream::BoxStream, FutureExt as _, StreamExt}; use runtime::rate_limiting::KeyedRateLimitConfig; use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter; +use tokio::sync::mpsc; use tokio_stream::wrappers::WatchStream; /// The GatewayNanny looks after the `Gateway` - on updates to the graph or config it'll @@ -56,7 +57,7 @@ pub(super) async fn new_gateway(config: Option) -> O .into_iter() .map(|(k, v)| { ( - k, + k.to_string(), runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, @@ -65,6 +66,8 @@ pub(super) async fn new_gateway(config: Option) -> O }) .collect::>(); + let (_, rx) = mpsc::channel(100); + let runtime = CliRuntime { fetcher: runtime_local::NativeFetcher::runtime_fetcher(), trusted_documents: runtime::trusted_documents_client::Client::new( @@ -72,7 +75,7 @@ pub(super) async fn new_gateway(config: Option) -> O ), kv: runtime_local::InMemoryKvStore::runtime(), meter: grafbase_telemetry::metrics::meter_from_global_provider(), - rate_limiter: InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }), + rate_limiter: InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx), }; let schema = config.try_into().ok()?; diff --git a/engine/crates/integration-tests/src/federation/builder/test_runtime.rs b/engine/crates/integration-tests/src/federation/builder/test_runtime.rs index 4cc5d792a6..9d546cc57b 100644 --- a/engine/crates/integration-tests/src/federation/builder/test_runtime.rs +++ b/engine/crates/integration-tests/src/federation/builder/test_runtime.rs @@ -4,6 +4,7 @@ use runtime_local::{ rate_limiting::in_memory::key_based::InMemoryRateLimiter, InMemoryHotCacheFactory, InMemoryKvStore, NativeFetcher, }; use runtime_noop::trusted_documents::NoopTrustedDocuments; +use tokio::sync::mpsc; pub struct TestRuntime { pub fetcher: runtime::fetch::Fetcher, @@ -16,13 +17,14 @@ pub struct TestRuntime { impl Default for TestRuntime { fn default() -> Self { + let (_, rx) = mpsc::channel(100); Self { fetcher: NativeFetcher::runtime_fetcher(), trusted_documents: trusted_documents_client::Client::new(NoopTrustedDocuments), kv: InMemoryKvStore::runtime(), meter: metrics::meter_from_global_provider(), hooks: Default::default(), - rate_limiter: InMemoryRateLimiter::runtime(Default::default()), + rate_limiter: InMemoryRateLimiter::runtime(Default::default(), rx), } } } diff --git a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs index 4685278487..27006fddcb 100644 --- a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs +++ b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs @@ -1,14 +1,17 @@ -use std::collections::HashMap; use std::net::IpAddr; use std::num::NonZeroU32; +use std::sync::Arc; +use std::{collections::HashMap, sync::RwLock}; use futures_util::future::BoxFuture; use futures_util::FutureExt; use governor::Quota; +use grafbase_telemetry::span::GRAFBASE_TARGET; use serde_json::Value; use http::{HeaderName, HeaderValue}; use runtime::rate_limiting::{Error, GraphRateLimit, KeyedRateLimitConfig, RateLimiter, RateLimiterContext}; +use tokio::sync::mpsc; pub struct RateLimitingContext(pub String); @@ -34,48 +37,72 @@ impl RateLimiterContext for RateLimitingContext { } } -#[derive(Default)] pub struct InMemoryRateLimiter { - inner: HashMap>, + limiters: Arc>>>, } impl InMemoryRateLimiter { - pub fn runtime(config: KeyedRateLimitConfig<'_>) -> RateLimiter { - let mut limiter = Self::default(); + pub fn runtime( + config: KeyedRateLimitConfig, + mut updates: mpsc::Receiver>, + ) -> RateLimiter { + let mut limiters = HashMap::new(); // add subgraph rate limiting configuration - for (name, rate_limit_config) in config.rate_limiting_configs { - limiter = limiter.with_rate_limiter(name, rate_limit_config); + for (name, config) in config.rate_limiting_configs { + let Some(limiter) = create_limiter(config) else { + continue; + }; + + limiters.insert(name.to_string(), limiter); } - RateLimiter::new(limiter) - } + let limiters = Arc::new(RwLock::new(limiters)); + let limiters_copy = limiters.clone(); + + tokio::spawn(async move { + while let Some(updates) = updates.recv().await { + let mut limiters = limiters_copy.write().unwrap(); + + for (name, config) in updates { + let Some(limiter) = create_limiter(config) else { + continue; + }; + + limiters.insert(name.to_string(), limiter); + } + } + }); - pub fn with_rate_limiter(mut self, key: &str, rate_limit_config: GraphRateLimit) -> Self { - let quota = (rate_limit_config.limit as u64) - .checked_div(rate_limit_config.duration.as_secs()) - .expect("rate limiter with invalid per second quota"); - - self.inner.insert( - key.to_string(), - governor::RateLimiter::keyed(Quota::per_second( - NonZeroU32::new(quota as u32).expect("rate limit duration cannot be 0"), - )), - ); - self + RateLimiter::new(Self { limiters }) } } +fn create_limiter(rate_limit_config: GraphRateLimit) -> Option> { + let Some(quota) = (rate_limit_config.limit as u64).checked_div(rate_limit_config.duration.as_secs()) else { + tracing::error!(target: GRAFBASE_TARGET, "the duration for rate limit cannot be zero"); + return None; + }; + + let Some(quota) = NonZeroU32::new(quota as u32) else { + tracing::error!(target: GRAFBASE_TARGET, "the limit is too low per defined duration"); + return None; + }; + + Some(governor::RateLimiter::keyed(Quota::per_second(quota))) +} + impl runtime::rate_limiting::RateLimiterInner for InMemoryRateLimiter { fn limit<'a>(&'a self, context: &'a dyn RateLimiterContext) -> BoxFuture<'a, Result<(), Error>> { async { - if let Some(key) = context.key() { - if let Some(rate_limiter) = self.inner.get(key) { - rate_limiter - .check_key(&usize::MIN) - .map_err(|_err| Error::ExceededCapacity)?; - }; - } + let Some(key) = context.key() else { return Ok(()) }; + let limiters = self.limiters.read().unwrap(); + + if let Some(rate_limiter) = limiters.get(key) { + rate_limiter + .check_key(&usize::MIN) + .map_err(|_err| Error::ExceededCapacity)?; + }; Ok(()) } diff --git a/engine/crates/runtime-local/src/rate_limiting/redis.rs b/engine/crates/runtime-local/src/rate_limiting/redis.rs index 330c4da9df..99fdf0e1f2 100644 --- a/engine/crates/runtime-local/src/rate_limiting/redis.rs +++ b/engine/crates/runtime-local/src/rate_limiting/redis.rs @@ -15,6 +15,7 @@ use futures_util::future::BoxFuture; use grafbase_telemetry::span::GRAFBASE_TARGET; use redis::ClientTlsConfig; use runtime::rate_limiting::{Error, GraphRateLimit, RateLimiter, RateLimiterContext}; +use tokio::sync::watch; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct RateLimitRedisConfig<'a> { @@ -30,6 +31,8 @@ pub struct RateLimitRedisTlsConfig<'a> { pub ca: Option<&'a Path>, } +pub type Limits = watch::Receiver>; + /// Rate limiter by utilizing Redis as a backend. It uses a averaging fixed window algorithm /// to define is the limit reached or not. /// @@ -47,7 +50,7 @@ pub struct RateLimitRedisTlsConfig<'a> { pub struct RedisRateLimiter { pool: Pool, key_prefix: String, - subgraph_limits: HashMap, + limits: Limits, } enum Key<'a> { @@ -69,18 +72,11 @@ impl<'a> fmt::Display for Key<'a> { } impl RedisRateLimiter { - pub async fn runtime( - config: RateLimitRedisConfig<'_>, - subgraph_limits: impl IntoIterator, - ) -> anyhow::Result { - let inner = Self::new(config, subgraph_limits).await?; - Ok(RateLimiter::new(inner)) + pub async fn runtime(config: RateLimitRedisConfig<'_>, limits: Limits) -> anyhow::Result { + Ok(RateLimiter::new(Self::new(config, limits).await?)) } - pub async fn new( - config: RateLimitRedisConfig<'_>, - subgraph_limits: impl IntoIterator, - ) -> anyhow::Result { + pub async fn new(config: RateLimitRedisConfig<'_>, limits: Limits) -> anyhow::Result { let tls_config = match config.tls { Some(tls) => { let client_tls = match tls.cert.zip(tls.key) { @@ -144,15 +140,10 @@ impl RedisRateLimiter { } }; - let subgraph_limits = subgraph_limits - .into_iter() - .map(|(key, value)| (key.to_string(), value)) - .collect(); - Ok(Self { pool, key_prefix: config.key_prefix.to_string(), - subgraph_limits, + limits, }) } @@ -167,7 +158,7 @@ impl RedisRateLimiter { async fn limit_inner(&self, context: &dyn RateLimiterContext) -> Result<(), Error> { let Some(key) = context.key() else { return Ok(()) }; - let Some(config) = self.subgraph_limits.get(key) else { + let Some(config) = self.limits.borrow().get(key).copied() else { return Ok(()); }; diff --git a/engine/crates/runtime/src/rate_limiting.rs b/engine/crates/runtime/src/rate_limiting.rs index 08185f3d45..710852165c 100644 --- a/engine/crates/runtime/src/rate_limiting.rs +++ b/engine/crates/runtime/src/rate_limiting.rs @@ -59,6 +59,6 @@ pub struct GraphRateLimit { } #[derive(Debug, Clone, Default)] -pub struct KeyedRateLimitConfig<'a> { - pub rate_limiting_configs: HashMap<&'a str, GraphRateLimit>, +pub struct KeyedRateLimitConfig { + pub rate_limiting_configs: HashMap, } diff --git a/gateway/crates/federated-server/Cargo.toml b/gateway/crates/federated-server/Cargo.toml index 524976c5c7..9029e82402 100644 --- a/gateway/crates/federated-server/Cargo.toml +++ b/gateway/crates/federated-server/Cargo.toml @@ -54,6 +54,8 @@ axum-aws-lambda = { version = "0.7.0", optional = true } tower = { workspace = true, optional = true } lambda_http = { version = "0.11.1", optional = true } serde_regex = "1.1.0" +notify = "6.1.1" +toml = "0.8.12" [dev-dependencies] indoc = "2.0.5" diff --git a/gateway/crates/federated-server/src/config.rs b/gateway/crates/federated-server/src/config.rs index 0caddb6287..c110c2f810 100644 --- a/gateway/crates/federated-server/src/config.rs +++ b/gateway/crates/federated-server/src/config.rs @@ -3,9 +3,15 @@ mod cors; mod entity_caching; mod header; mod health; +pub(crate) mod hot_reload; mod rate_limit; -use std::{collections::BTreeMap, net::SocketAddr, path::PathBuf, time::Duration}; +use std::{ + collections::{BTreeMap, HashMap}, + net::SocketAddr, + path::PathBuf, + time::Duration, +}; pub use self::health::HealthConfig; use ascii::AsciiString; @@ -67,6 +73,25 @@ pub struct Config { pub entity_caching: EntityCachingConfig, } +impl Config { + /// Load the rate limit configuration for global and subgraph level settings. + pub fn as_keyed_rate_limit_config(&self) -> HashMap<&str, GraphRateLimit> { + let mut key_based_config = HashMap::new(); + + if let Some(global_config) = self.gateway.rate_limit.as_ref().and_then(|c| c.global) { + key_based_config.insert("global", global_config); + } + + for (subgraph_name, subgraph) in self.subgraphs.iter() { + if let Some(limit) = subgraph.rate_limit { + key_based_config.insert(subgraph_name, limit); + } + } + + key_based_config + } +} + #[derive(Debug, Default, serde::Deserialize)] #[serde(deny_unknown_fields)] pub struct GatewayConfig { diff --git a/gateway/crates/federated-server/src/config/hot_reload.rs b/gateway/crates/federated-server/src/config/hot_reload.rs new file mode 100644 index 0000000000..6c9fa3c5a1 --- /dev/null +++ b/gateway/crates/federated-server/src/config/hot_reload.rs @@ -0,0 +1,124 @@ +use std::{collections::HashMap, fs, path::PathBuf, sync::OnceLock, time::Duration}; + +use grafbase_telemetry::span::GRAFBASE_TARGET; +use notify::{EventHandler, EventKind, PollWatcher, Watcher}; +use runtime::rate_limiting::GraphRateLimit; +use tokio::sync::{mpsc, watch}; + +use crate::Config; + +type RateLimitData = HashMap; + +pub(crate) enum RateLimitSender { + Watch(watch::Sender), + Mpsc(mpsc::Sender), +} + +impl RateLimitSender { + fn send(&self, data: RateLimitData) -> crate::Result<()> { + match self { + RateLimitSender::Watch(channel) => Ok(channel.send(data)?), + RateLimitSender::Mpsc(channel) => Ok(channel.blocking_send(data)?), + } + } +} + +impl From> for RateLimitSender { + fn from(value: watch::Sender) -> Self { + Self::Watch(value) + } +} + +impl From> for RateLimitSender { + fn from(value: mpsc::Sender) -> Self { + Self::Mpsc(value) + } +} + +pub(crate) struct ConfigWatcher { + config_path: PathBuf, + rate_limit_sender: RateLimitSender, +} + +impl ConfigWatcher { + pub fn new(config_path: PathBuf, rate_limit_sender: impl Into) -> Self { + Self { + config_path, + rate_limit_sender: rate_limit_sender.into(), + } + } + + pub fn watch(self) -> crate::Result<()> { + static WATCHER: OnceLock = OnceLock::new(); + + WATCHER.get_or_init(|| { + let config = notify::Config::default().with_poll_interval(Duration::from_secs(1)); + let path = self.config_path.clone(); + let mut watcher = PollWatcher::new(self, config).expect("config watch init failed"); + + watcher + .watch(&path, notify::RecursiveMode::NonRecursive) + .expect("config watch failed"); + + watcher + }); + + Ok(()) + } + + fn reload_config(&self) -> crate::Result<()> { + let config = match fs::read_to_string(&self.config_path) { + Ok(config) => config, + Err(e) => { + tracing::error!(target: GRAFBASE_TARGET, "error reading gateway config: {e}"); + + return Ok(()); + } + }; + + let config: Config = match toml::from_str(&config) { + Ok(config) => config, + Err(e) => { + tracing::error!(target: GRAFBASE_TARGET, "error parsing gateway config: {e}"); + + return Ok(()); + } + }; + + let rate_limiting_configs = config + .as_keyed_rate_limit_config() + .into_iter() + .map(|(k, v)| { + ( + k.to_string(), + runtime::rate_limiting::GraphRateLimit { + limit: v.limit, + duration: v.duration, + }, + ) + }) + .collect(); + + self.rate_limit_sender.send(rate_limiting_configs)?; + + Ok(()) + } +} + +impl EventHandler for ConfigWatcher { + fn handle_event(&mut self, event: notify::Result) { + match event.map(|e| e.kind) { + Ok(EventKind::Any | EventKind::Create(_) | EventKind::Modify(_) | EventKind::Other) => { + tracing::debug!(target: GRAFBASE_TARGET, "reloading configuration file"); + + if let Err(e) = self.reload_config() { + tracing::error!(target: GRAFBASE_TARGET, "error reloading gateway config: {e}"); + }; + } + Ok(_) => (), + Err(e) => { + tracing::error!(target: GRAFBASE_TARGET, "error reading gateway config: {e}"); + } + } + } +} diff --git a/gateway/crates/federated-server/src/config/rate_limit.rs b/gateway/crates/federated-server/src/config/rate_limit.rs index 0986c8bb48..7ba7b4dd70 100644 --- a/gateway/crates/federated-server/src/config/rate_limit.rs +++ b/gateway/crates/federated-server/src/config/rate_limit.rs @@ -4,7 +4,7 @@ use serde::Deserializer; use std::path::PathBuf; use std::time::Duration; -#[derive(Debug, Clone, serde::Deserialize)] +#[derive(Debug, Clone, Copy, serde::Deserialize)] #[serde(deny_unknown_fields)] pub struct GraphRateLimit { pub limit: usize, diff --git a/gateway/crates/federated-server/src/error.rs b/gateway/crates/federated-server/src/error.rs index 3c67489581..3467a0e3ce 100644 --- a/gateway/crates/federated-server/src/error.rs +++ b/gateway/crates/federated-server/src/error.rs @@ -1,4 +1,4 @@ -use tokio::sync::watch::error::SendError; +use tokio::sync::{mpsc, watch}; /// The Grafbase gateway error type #[derive(Debug, thiserror::Error)] @@ -17,8 +17,14 @@ pub enum Error { Server(#[source] std::io::Error), } -impl From> for Error { - fn from(value: SendError) -> Self { +impl From> for Error { + fn from(value: watch::error::SendError) -> Self { + Self::InternalError(value.to_string()) + } +} + +impl From> for Error { + fn from(value: mpsc::error::SendError) -> Self { Self::InternalError(value.to_string()) } } diff --git a/gateway/crates/federated-server/src/lib.rs b/gateway/crates/federated-server/src/lib.rs index 1c7f37b18b..91cc5dc89c 100644 --- a/gateway/crates/federated-server/src/lib.rs +++ b/gateway/crates/federated-server/src/lib.rs @@ -16,4 +16,4 @@ mod server; /// The crate result type. pub type Result = std::result::Result; -pub use server::serve; +pub use server::{serve, ServerConfig}; diff --git a/gateway/crates/federated-server/src/server.rs b/gateway/crates/federated-server/src/server.rs index 2181376171..2d9c170984 100644 --- a/gateway/crates/federated-server/src/server.rs +++ b/gateway/crates/federated-server/src/server.rs @@ -25,6 +25,7 @@ use grafbase_telemetry::span::GRAFBASE_TARGET; use state::ServerState; use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, + path::PathBuf, time::Duration, }; use tokio::sync::mpsc; @@ -34,14 +35,35 @@ use self::gateway::GatewayConfig; const DEFAULT_LISTEN_ADDRESS: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5000); +/// Start parameter for the gateway. +pub struct ServerConfig { + /// The GraphQL endpoint listen address. + pub listen_addr: Option, + /// The gateway configuration. + pub config: Config, + /// The config file path for hot reload. + pub config_path: Option, + /// If true, watches changes to the config + /// and reloads _some_ of the things. + pub config_hot_reload: bool, + /// The way of loading the graph for the gateway. + pub fetch_method: GraphFetchMethod, + /// The opentelemetry tracer. + pub otel_tracing: Option, +} + /// Starts the self-hosted Grafbase gateway. If started with a schema path, will /// not connect our API for changes in the schema and if started without, we poll /// the schema registry every ten second for changes. pub async fn serve( - listen_addr: Option, - config: Config, - fetch_method: GraphFetchMethod, - otel_tracing: Option, + ServerConfig { + listen_addr, + config, + config_path, + fetch_method, + otel_tracing, + config_hot_reload, + }: ServerConfig, ) -> crate::Result<()> { let path = config.graph.path.as_deref().unwrap_or("/graphql"); @@ -74,6 +96,8 @@ pub async fn serve( rate_limit: config.gateway.rate_limit, timeout: config.gateway.timeout, entity_caching: config.entity_caching, + config_hot_reload, + config_path, }, otel_reload, sender, diff --git a/gateway/crates/federated-server/src/server/gateway.rs b/gateway/crates/federated-server/src/server/gateway.rs index 0a47fdbe24..8946ca92d4 100644 --- a/gateway/crates/federated-server/src/server/gateway.rs +++ b/gateway/crates/federated-server/src/server/gateway.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; +use std::path::PathBuf; use std::time::Duration; use std::{collections::BTreeMap, sync::Arc}; use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter; use runtime_local::rate_limiting::redis::RedisRateLimiter; -use tokio::sync::watch; +use tokio::sync::{mpsc, watch}; use engine_v2::Engine; use graphql_composition::FederatedGraph; @@ -13,9 +14,11 @@ use runtime::rate_limiting::KeyedRateLimitConfig; use runtime_local::{ComponentLoader, HooksWasi, HooksWasiConfig, InMemoryKvStore}; use runtime_noop::trusted_documents::NoopTrustedDocuments; -use crate::config::EntityCachingConfig; use crate::{ - config::{AuthenticationConfig, OperationLimitsConfig, RateLimitConfig, SubgraphConfig, TrustedDocumentsConfig}, + config::{ + hot_reload, AuthenticationConfig, EntityCachingConfig, OperationLimitsConfig, RateLimitConfig, SubgraphConfig, + TrustedDocumentsConfig, + }, HeaderRule, }; @@ -40,6 +43,8 @@ pub(crate) struct GatewayConfig { pub rate_limit: Option, pub timeout: Option, pub entity_caching: EntityCachingConfig, + pub config_hot_reload: bool, + pub config_path: Option, } /// Creates a new gateway from federated schema. @@ -59,6 +64,8 @@ pub(super) async fn generate( rate_limit, timeout, entity_caching, + config_hot_reload, + config_path, } = config; let schema_version = blake3::hash(federated_schema.as_bytes()); @@ -135,7 +142,7 @@ pub(super) async fn generate( .into_iter() .map(|(k, v)| { ( - k, + k.to_string(), runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, @@ -146,6 +153,8 @@ pub(super) async fn generate( let rate_limiter = match config.rate_limit_config() { Some(config) if config.storage.is_redis() => { + let (tx, rx) = watch::channel(rate_limiting_configs); + let tls = config .redis .tls @@ -161,11 +170,29 @@ pub(super) async fn generate( tls, }; - RedisRateLimiter::runtime(global_config, rate_limiting_configs) + match config_path { + Some(path) if config_hot_reload => { + hot_reload::ConfigWatcher::new(path, tx).watch()?; + } + _ => (), + } + + RedisRateLimiter::runtime(global_config, rx) .await .map_err(|e| crate::Error::InternalError(e.to_string()))? } - _ => InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }), + _ => { + let (tx, rx) = mpsc::channel(100); + + match config_path { + Some(path) if config_hot_reload => { + hot_reload::ConfigWatcher::new(path, tx).watch()?; + } + _ => (), + } + + InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx) + } }; let runtime = GatewayRuntime { diff --git a/gateway/crates/gateway-binary/src/args.rs b/gateway/crates/gateway-binary/src/args.rs index 4f7923d3e2..c52a1593ff 100644 --- a/gateway/crates/gateway-binary/src/args.rs +++ b/gateway/crates/gateway-binary/src/args.rs @@ -2,7 +2,7 @@ mod lambda; mod log; mod std; -use ::std::net::SocketAddr; +use ::std::{net::SocketAddr, path::Path}; use clap::Parser; use federated_server::{Config, GraphFetchMethod}; @@ -20,6 +20,10 @@ pub(crate) trait Args { fn config(&self) -> anyhow::Result; + fn config_path(&self) -> Option<&Path>; + + fn hot_reload(&self) -> bool; + fn log_format(&self) -> BoxedLayer where S: Subscriber + for<'span> LookupSpan<'span> + Send + Sync; diff --git a/gateway/crates/gateway-binary/src/args/lambda.rs b/gateway/crates/gateway-binary/src/args/lambda.rs index 224617f909..f790527f0c 100644 --- a/gateway/crates/gateway-binary/src/args/lambda.rs +++ b/gateway/crates/gateway-binary/src/args/lambda.rs @@ -1,4 +1,8 @@ -use std::{fs, io::ErrorKind, path::PathBuf}; +use std::{ + fs, + io::ErrorKind, + path::{Path, PathBuf}, +}; use anyhow::Context; use clap::Parser; @@ -49,6 +53,10 @@ impl super::Args for Args { } } + fn config_path(&self) -> Option<&Path> { + Some(&self.config) + } + fn log_format(&self) -> BoxedLayer where S: Subscriber + for<'span> LookupSpan<'span> + Send + Sync, @@ -64,6 +72,10 @@ impl super::Args for Args { } } + fn hot_reload(&self) -> bool { + false + } + fn listen_address(&self) -> Option { None } diff --git a/gateway/crates/gateway-binary/src/args/std.rs b/gateway/crates/gateway-binary/src/args/std.rs index 4f39f41fa6..d7c6779da7 100644 --- a/gateway/crates/gateway-binary/src/args/std.rs +++ b/gateway/crates/gateway-binary/src/args/std.rs @@ -1,4 +1,8 @@ -use std::{fs, net::SocketAddr, path::PathBuf}; +use std::{ + fs, + net::SocketAddr, + path::{Path, PathBuf}, +}; use anyhow::Context; use ascii::AsciiString; @@ -52,6 +56,9 @@ pub struct Args { /// Set the style of log output #[arg(long, env = "GRAFBASE_LOG_STYLE", default_value_t = LogStyle::Text)] log_style: LogStyle, + /// If set, parts of the configuration will get reloaded when changed. + #[arg(long, action)] + hot_reload: bool, } impl super::Args for Args { @@ -78,6 +85,14 @@ impl super::Args for Args { } } + fn config_path(&self) -> Option<&Path> { + self.config.as_deref() + } + + fn hot_reload(&self) -> bool { + self.hot_reload + } + fn config(&self) -> anyhow::Result { let mut config = match self.config.as_ref() { Some(path) => { diff --git a/gateway/crates/gateway-binary/src/main.rs b/gateway/crates/gateway-binary/src/main.rs index 6410185878..86908e3e61 100644 --- a/gateway/crates/gateway-binary/src/main.rs +++ b/gateway/crates/gateway-binary/src/main.rs @@ -11,7 +11,7 @@ use tracing::{error, Subscriber}; use tracing_subscriber::registry::LookupSpan; use tracing_subscriber::{reload, EnvFilter, Layer, Registry}; -use federated_server::{Config, GraphFetchMethod, OtelReload, OtelTracing}; +use federated_server::{Config, GraphFetchMethod, OtelReload, OtelTracing, ServerConfig}; use grafbase_telemetry::config::TelemetryConfig; use grafbase_telemetry::error::TracingError; use grafbase_telemetry::otel::layer::BoxedLayer; @@ -42,6 +42,7 @@ fn main() -> anyhow::Result<()> { runtime.block_on(async move { let otel_tracing = if std::env::var("__GRAFBASE_RUST_LOG").is_ok() { let filter = tracing_subscriber::filter::EnvFilter::try_from_env("__GRAFBASE_RUST_LOG").unwrap_or_default(); + tracing_subscriber::fmt() .pretty() .with_env_filter(filter) @@ -50,7 +51,9 @@ fn main() -> anyhow::Result<()> { .with_target(true) .without_time() .init(); + tracing::warn!("Skipping OTEL configuration."); + None } else { setup_tracing(&mut config, &args)? @@ -59,7 +62,15 @@ fn main() -> anyhow::Result<()> { let crate_version = crate_version!(); tracing::info!(target: GRAFBASE_TARGET, "Grafbase Gateway {crate_version}"); - federated_server::serve(args.listen_address(), config, args.fetch_method()?, otel_tracing).await?; + federated_server::serve(ServerConfig { + listen_addr: args.listen_address(), + config, + config_path: args.config_path().map(|p| p.to_owned()), + config_hot_reload: args.hot_reload(), + fetch_method: args.fetch_method()?, + otel_tracing, + }) + .await?; Ok::<(), anyhow::Error>(()) })?;