Skip to content

feat: add support for dynamic rate limit configurations with hot reload #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: pr3-base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions cli/crates/federated-dev/src/dev/gateway_nanny.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use futures_concurrency::stream::Merge;
use futures_util::{future::BoxFuture, stream::BoxStream, FutureExt as _, StreamExt};
use runtime::rate_limiting::KeyedRateLimitConfig;
use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter;
use tokio::sync::mpsc;
use tokio_stream::wrappers::WatchStream;

/// The GatewayNanny looks after the `Gateway` - on updates to the graph or config it'll
Expand Down Expand Up @@ -56,7 +57,7 @@ pub(super) async fn new_gateway(config: Option<engine_v2::VersionedConfig>) -> O
.into_iter()
.map(|(k, v)| {
(
k,
k.to_string(),
runtime::rate_limiting::GraphRateLimit {
limit: v.limit,
duration: v.duration,
Expand All @@ -65,14 +66,16 @@ pub(super) async fn new_gateway(config: Option<engine_v2::VersionedConfig>) -> O
})
.collect::<HashMap<_, _>>();

let (_, rx) = mpsc::channel(100);

let runtime = CliRuntime {
fetcher: runtime_local::NativeFetcher::runtime_fetcher(),
trusted_documents: runtime::trusted_documents_client::Client::new(
runtime_noop::trusted_documents::NoopTrustedDocuments,
),
kv: runtime_local::InMemoryKvStore::runtime(),
meter: grafbase_telemetry::metrics::meter_from_global_provider(),
rate_limiter: InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }),
rate_limiter: InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx),
};

let schema = config.try_into().ok()?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use runtime_local::{
rate_limiting::in_memory::key_based::InMemoryRateLimiter, InMemoryHotCacheFactory, InMemoryKvStore, NativeFetcher,
};
use runtime_noop::trusted_documents::NoopTrustedDocuments;
use tokio::sync::mpsc;

pub struct TestRuntime {
pub fetcher: runtime::fetch::Fetcher,
Expand All @@ -16,13 +17,14 @@ pub struct TestRuntime {

impl Default for TestRuntime {
fn default() -> Self {
let (_, rx) = mpsc::channel(100);
Self {
fetcher: NativeFetcher::runtime_fetcher(),
trusted_documents: trusted_documents_client::Client::new(NoopTrustedDocuments),
kv: InMemoryKvStore::runtime(),
meter: metrics::meter_from_global_provider(),
hooks: Default::default(),
rate_limiter: InMemoryRateLimiter::runtime(Default::default()),
rate_limiter: InMemoryRateLimiter::runtime(Default::default(), rx),
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::{collections::HashMap, sync::RwLock};

use futures_util::future::BoxFuture;
use futures_util::FutureExt;
use governor::Quota;
use grafbase_telemetry::span::GRAFBASE_TARGET;
use serde_json::Value;

use http::{HeaderName, HeaderValue};
use runtime::rate_limiting::{Error, GraphRateLimit, KeyedRateLimitConfig, RateLimiter, RateLimiterContext};
use tokio::sync::mpsc;

pub struct RateLimitingContext(pub String);

Expand All @@ -34,48 +37,72 @@ impl RateLimiterContext for RateLimitingContext {
}
}

#[derive(Default)]
pub struct InMemoryRateLimiter {
inner: HashMap<String, governor::DefaultKeyedRateLimiter<usize>>,
limiters: Arc<RwLock<HashMap<String, governor::DefaultKeyedRateLimiter<usize>>>>,
}

impl InMemoryRateLimiter {
pub fn runtime(config: KeyedRateLimitConfig<'_>) -> RateLimiter {
let mut limiter = Self::default();
pub fn runtime(
config: KeyedRateLimitConfig,
mut updates: mpsc::Receiver<HashMap<String, GraphRateLimit>>,
) -> RateLimiter {
let mut limiters = HashMap::new();

// add subgraph rate limiting configuration
for (name, rate_limit_config) in config.rate_limiting_configs {
limiter = limiter.with_rate_limiter(name, rate_limit_config);
for (name, config) in config.rate_limiting_configs {
let Some(limiter) = create_limiter(config) else {
continue;
};

limiters.insert(name.to_string(), limiter);
}

RateLimiter::new(limiter)
}
let limiters = Arc::new(RwLock::new(limiters));
let limiters_copy = limiters.clone();

tokio::spawn(async move {
while let Some(updates) = updates.recv().await {
let mut limiters = limiters_copy.write().unwrap();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

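💡 Robustness
Using unwrap() on the result of RwLock::write() will panic if the lock is poisoned, i.e. if another thread panicked while holding it. Because this call runs inside a spawned task, such a panic would end the task and silently stop all further rate limit updates from being applied. Consider handling the error gracefully instead.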

Suggested change
let mut limiters = limiters_copy.write().unwrap();
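Replace `let mut limiters = limiters_copy.write().unwrap();` with explicit handling of the `Err` case, e.g. `let Ok(mut limiters) = limiters_copy.write() else { tracing::error!(target: GRAFBASE_TARGET, "rate limiter lock is poisoned"); break; };` (or recover the guard with `PoisonError::into_inner`). Note that `?` cannot be used here, because the spawned task does not return a `Result`.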


for (name, config) in updates {
let Some(limiter) = create_limiter(config) else {
continue;
};

limiters.insert(name.to_string(), limiter);
}
}
});

pub fn with_rate_limiter(mut self, key: &str, rate_limit_config: GraphRateLimit) -> Self {
let quota = (rate_limit_config.limit as u64)
.checked_div(rate_limit_config.duration.as_secs())
.expect("rate limiter with invalid per second quota");

self.inner.insert(
key.to_string(),
governor::RateLimiter::keyed(Quota::per_second(
NonZeroU32::new(quota as u32).expect("rate limit duration cannot be 0"),
)),
);
self
RateLimiter::new(Self { limiters })
}
}

fn create_limiter(rate_limit_config: GraphRateLimit) -> Option<governor::DefaultKeyedRateLimiter<usize>> {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🐛 Bug
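The create_limiter function already guards against a zero divisor: `checked_div` returns `None` when `rate_limit_config.duration.as_secs()` is zero, so no division-by-zero panic can occur. Two edge cases remain worth checking, though: `as_secs()` presumably truncates a sub-second duration to zero, which would then be rejected with the misleading "cannot be zero" message, and the later `quota as u32` cast silently truncates quotas larger than `u32::MAX`. Both should be validated explicitly before building the limiter.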

Suggested change
fn create_limiter(rate_limit_config: GraphRateLimit) -> Option<governor::DefaultKeyedRateLimiter<usize>> {
let Some(quota) = (rate_limit_config.limit as u64).checked_div(rate_limit_config.duration.as_secs()).filter(|&d| d > 0) else { tracing::error!(target: GRAFBASE_TARGET, "the duration for rate limit cannot be zero"); return None; };

let Some(quota) = (rate_limit_config.limit as u64).checked_div(rate_limit_config.duration.as_secs()) else {
tracing::error!(target: GRAFBASE_TARGET, "the duration for rate limit cannot be zero");
return None;
};

let Some(quota) = NonZeroU32::new(quota as u32) else {
tracing::error!(target: GRAFBASE_TARGET, "the limit is too low per defined duration");
return None;
};

Some(governor::RateLimiter::keyed(Quota::per_second(quota)))
}

impl runtime::rate_limiting::RateLimiterInner for InMemoryRateLimiter {
fn limit<'a>(&'a self, context: &'a dyn RateLimiterContext) -> BoxFuture<'a, Result<(), Error>> {
async {
if let Some(key) = context.key() {
if let Some(rate_limiter) = self.inner.get(key) {
rate_limiter
.check_key(&usize::MIN)
.map_err(|_err| Error::ExceededCapacity)?;
};
}
let Some(key) = context.key() else { return Ok(()) };
let limiters = self.limiters.read().unwrap();

if let Some(rate_limiter) = limiters.get(key) {
rate_limiter
.check_key(&usize::MIN)
.map_err(|_err| Error::ExceededCapacity)?;
};

Ok(())
}
Expand Down
27 changes: 9 additions & 18 deletions engine/crates/runtime-local/src/rate_limiting/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use futures_util::future::BoxFuture;
use grafbase_telemetry::span::GRAFBASE_TARGET;
use redis::ClientTlsConfig;
use runtime::rate_limiting::{Error, GraphRateLimit, RateLimiter, RateLimiterContext};
use tokio::sync::watch;

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct RateLimitRedisConfig<'a> {
Expand All @@ -30,6 +31,8 @@ pub struct RateLimitRedisTlsConfig<'a> {
pub ca: Option<&'a Path>,
}

/// Receiving half of a `tokio::sync::watch` channel that carries the current set of
/// rate limits: a map from rate limit key to its [`GraphRateLimit`].
pub type Limits = watch::Receiver<HashMap<String, GraphRateLimit>>;

/// Rate limiter utilizing Redis as a backend. It uses an averaging fixed window algorithm
/// to determine whether the limit has been reached.
///
Expand All @@ -47,7 +50,7 @@ pub struct RateLimitRedisTlsConfig<'a> {
pub struct RedisRateLimiter {
pool: Pool<pool::Manager>,
key_prefix: String,
subgraph_limits: HashMap<String, GraphRateLimit>,
limits: Limits,
}

enum Key<'a> {
Expand All @@ -69,18 +72,11 @@ impl<'a> fmt::Display for Key<'a> {
}

impl RedisRateLimiter {
pub async fn runtime(
config: RateLimitRedisConfig<'_>,
subgraph_limits: impl IntoIterator<Item = (&str, GraphRateLimit)>,
) -> anyhow::Result<RateLimiter> {
let inner = Self::new(config, subgraph_limits).await?;
Ok(RateLimiter::new(inner))
pub async fn runtime(config: RateLimitRedisConfig<'_>, limits: Limits) -> anyhow::Result<RateLimiter> {
Ok(RateLimiter::new(Self::new(config, limits).await?))
}

pub async fn new(
config: RateLimitRedisConfig<'_>,
subgraph_limits: impl IntoIterator<Item = (&str, GraphRateLimit)>,
) -> anyhow::Result<RedisRateLimiter> {
pub async fn new(config: RateLimitRedisConfig<'_>, limits: Limits) -> anyhow::Result<RedisRateLimiter> {
let tls_config = match config.tls {
Some(tls) => {
let client_tls = match tls.cert.zip(tls.key) {
Expand Down Expand Up @@ -144,15 +140,10 @@ impl RedisRateLimiter {
}
};

let subgraph_limits = subgraph_limits
.into_iter()
.map(|(key, value)| (key.to_string(), value))
.collect();

Ok(Self {
pool,
key_prefix: config.key_prefix.to_string(),
subgraph_limits,
limits,
})
}

Expand All @@ -167,7 +158,7 @@ impl RedisRateLimiter {
async fn limit_inner(&self, context: &dyn RateLimiterContext) -> Result<(), Error> {
let Some(key) = context.key() else { return Ok(()) };

let Some(config) = self.subgraph_limits.get(key) else {
let Some(config) = self.limits.borrow().get(key).copied() else {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion
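Note that `self.limits.borrow()` does not return a `Result` and so has no failure case to handle. It does, however, hold a read lock on the watched value for as long as the returned guard is alive, so keep that borrow as short as possible: copy the entry out immediately (as `.copied()` does here) and never hold the guard across an `.await`.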

return Ok(());
};

Expand Down
4 changes: 2 additions & 2 deletions engine/crates/runtime/src/rate_limiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ pub struct GraphRateLimit {
}

#[derive(Debug, Clone, Default)]
pub struct KeyedRateLimitConfig<'a> {
pub rate_limiting_configs: HashMap<&'a str, GraphRateLimit>,
pub struct KeyedRateLimitConfig {
pub rate_limiting_configs: HashMap<String, GraphRateLimit>,
}
2 changes: 2 additions & 0 deletions gateway/crates/federated-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ axum-aws-lambda = { version = "0.7.0", optional = true }
tower = { workspace = true, optional = true }
lambda_http = { version = "0.11.1", optional = true }
serde_regex = "1.1.0"
notify = "6.1.1"
toml = "0.8.12"

[dev-dependencies]
indoc = "2.0.5"
Expand Down
27 changes: 26 additions & 1 deletion gateway/crates/federated-server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@ mod cors;
mod entity_caching;
mod header;
mod health;
pub(crate) mod hot_reload;
mod rate_limit;

use std::{collections::BTreeMap, net::SocketAddr, path::PathBuf, time::Duration};
use std::{
collections::{BTreeMap, HashMap},
net::SocketAddr,
path::PathBuf,
time::Duration,
};

pub use self::health::HealthConfig;
use ascii::AsciiString;
Expand Down Expand Up @@ -67,6 +73,25 @@ pub struct Config {
pub entity_caching: EntityCachingConfig,
}

impl Config {
    /// Collects every configured rate limit into a single map keyed by name.
    ///
    /// The gateway-wide limit, when present, is stored under the key `"global"`.
    /// Each subgraph that defines a limit is stored under its own subgraph name.
    /// Subgraphs are inserted after the global entry, so a subgraph whose name is
    /// literally `"global"` replaces the gateway-wide entry.
    pub fn as_keyed_rate_limit_config(&self) -> HashMap<&str, GraphRateLimit> {
        let mut limits = HashMap::new();

        // Gateway-wide limit first; it is optional at both levels of the config.
        let global_limit = self.gateway.rate_limit.as_ref().and_then(|settings| settings.global);
        if let Some(limit) = global_limit {
            limits.insert("global", limit);
        }

        // Then one entry per subgraph that actually declares a limit.
        for (name, subgraph) in self.subgraphs.iter() {
            let Some(limit) = subgraph.rate_limit else {
                continue;
            };

            limits.insert(name, limit);
        }

        limits
    }
}

#[derive(Debug, Default, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct GatewayConfig {
Expand Down
Loading