diff --git a/quickwit/quickwit-common/src/pubsub.rs b/quickwit/quickwit-common/src/pubsub.rs index 9a187f64409..2aeab590dbe 100644 --- a/quickwit/quickwit-common/src/pubsub.rs +++ b/quickwit/quickwit-common/src/pubsub.rs @@ -66,7 +66,7 @@ impl EventBroker { .inner .subscriptions .lock() - .expect("The lock should never be poisoned."); + .expect("the lock should not be poisoned"); if !subscriptions.contains::>() { subscriptions.insert::>(HashMap::new()); @@ -99,7 +99,7 @@ impl EventBroker { .inner .subscriptions .lock() - .expect("The lock should never be poisoned."); + .expect("the lock should not be poisoned"); if let Some(typed_subscriptions) = subscriptions.get::>() { for subscription in typed_subscriptions.values() { @@ -141,7 +141,7 @@ where E: Event let mut subscriptions = broker .subscriptions .lock() - .expect("The lock should never be poisoned."); + .expect("the lock should not be poisoned"); if let Some(typed_subscriptions) = subscriptions.get_mut::>() { typed_subscriptions.remove(&self.subscription_id); } @@ -178,20 +178,24 @@ mod tests { #[tokio::test] async fn test_event_broker() { - let broker = EventBroker::default(); + let event_broker = EventBroker::default(); let counter = Arc::new(AtomicUsize::new(0)); let subscriber = MySubscriber { counter: counter.clone(), }; - let subscription = broker.subscribe(subscriber); + let subscription_handle = event_broker.subscribe(subscriber); + let event = MyEvent { value: 42 }; - broker.publish(event); + event_broker.publish(event); + tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; assert_eq!(counter.load(Ordering::Relaxed), 42); - subscription.cancel(); + subscription_handle.cancel(); + let event = MyEvent { value: 1337 }; - broker.publish(event); + event_broker.publish(event); + tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; assert_eq!(counter.load(Ordering::Relaxed), 42); } diff --git a/quickwit/quickwit-common/src/tower/event_listener.rs b/quickwit/quickwit-common/src/tower/event_listener.rs new file mode 100644 index 00000000000..270e99ec633 --- /dev/null +++ b/quickwit/quickwit-common/src/tower/event_listener.rs @@ -0,0 +1,174 @@ +// Copyright (C) 2023 Quickwit, Inc. +// +// Quickwit is offered under the AGPL v3.0 and as commercial software. +// For commercial licensing, contact us at hello@quickwit.io. +// +// AGPL: +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{ready, Future}; +use pin_project::pin_project; +use tower::{Layer, Service}; + +use crate::pubsub::{Event, EventBroker}; + +pub struct EventListener { + inner: S, + event_broker: EventBroker, +} + +impl EventListener { + pub fn new(inner: S, event_broker: EventBroker) -> Self { + Self { + inner, + event_broker, + } + } +} + +impl Service for EventListener +where + S: Service, + R: Event, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: R) -> Self::Future { + let inner = self.inner.call(request.clone()); + ResponseFuture { + inner, + event_broker: self.event_broker.clone(), + request: Some(request), + } + } +} + +#[derive(Debug, Clone)] +pub struct EventListenerLayer { + event_broker: EventBroker, +} + +impl EventListenerLayer { + pub fn new(event_broker: EventBroker) -> Self { + Self { event_broker } + } +} + +impl Layer for EventListenerLayer { + type Service = EventListener; + + fn layer(&self, service: S) -> Self::Service { + EventListener::new(service, self.event_broker.clone()) + } +} + +/// Response future for [`EventListener`]. +#[pin_project] +pub struct ResponseFuture { + #[pin] + inner: F, + event_broker: EventBroker, + request: Option, +} + +impl Future for ResponseFuture +where + R: Event, + F: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let response = ready!(this.inner.poll(cx)); + + if response.is_ok() { + this.event_broker + .publish(this.request.take().expect("request should be set")); + } + Poll::Ready(Ok(response?)) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::time::Duration; + + use async_trait::async_trait; + + use super::*; + use crate::pubsub::EventSubscriber; + + #[derive(Debug, Clone, Copy)] + struct MyEvent { + return_ok: bool, + } + + impl Event for MyEvent {} + + #[derive(Debug, Clone)] + struct MySubscriber { + counter: Arc, + } + + #[async_trait] + impl EventSubscriber for MySubscriber { + async fn handle_event(&mut self, _event: MyEvent) { + self.counter.fetch_add(1, Ordering::Relaxed); + } + } + + #[tokio::test] + async fn test_event_listener() { + let event_broker = EventBroker::default(); + let counter = Arc::new(AtomicUsize::new(0)); + let subscriber = MySubscriber { + counter: counter.clone(), + }; + let _subscription_handle = event_broker.subscribe::(subscriber); + + let layer = EventListenerLayer::new(event_broker); + + let mut service = layer.layer(tower::service_fn(|request: MyEvent| async move { + if request.return_ok { + Ok(()) + } else { + Err(()) + } + })); + let request = MyEvent { return_ok: false }; + service.call(request).await.unwrap_err(); + + tokio::time::sleep(Duration::from_millis(1)).await; + assert_eq!(counter.load(Ordering::Relaxed), 0); + + let request = MyEvent { return_ok: true }; + service.call(request).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(1)).await; + assert_eq!(counter.load(Ordering::Relaxed), 1); + + } +} diff --git a/quickwit/quickwit-common/src/tower/mod.rs b/quickwit/quickwit-common/src/tower/mod.rs index 3434bf925c7..1aacd7fec8c 100644 --- a/quickwit/quickwit-common/src/tower/mod.rs +++ b/quickwit/quickwit-common/src/tower/mod.rs @@ -22,6 +22,7 @@ mod box_service; mod buffer; mod change; mod estimate_rate; +mod event_listener; mod metrics; mod pool; mod rate; @@ -38,6 +39,7 @@ pub use box_service::BoxService; pub use buffer::{Buffer, BufferError, BufferLayer}; pub use change::Change; pub use estimate_rate::{EstimateRate, EstimateRateLayer}; +pub use event_listener::{EventListener, EventListenerLayer}; use futures::Future; pub use metrics::{PrometheusMetrics, PrometheusMetricsLayer}; pub use pool::Pool;