Skip to content

Commit

Permalink
Remove InternalEventHandler (#2952)
Browse files Browse the repository at this point in the history
Since having both a regular event handler and a raw handler at the same time is
now supported, the `InternalEventHandler` enum is no longer needed to ensure
mutual exclusivity. 

Additionally, there was a subtle bug introduced when the `Both` variant was
added: ratelimit events and shard stage update events would not be dispatched
because the conditionals to dispatch those events only checked for the `Normal`
variant. Inlining the regular/raw event handler fields improves ergonomics and
fixes said bug.
  • Loading branch information
mkrasnitski authored Aug 26, 2024
1 parent 849daf3 commit c3d4a33
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 65 deletions.
19 changes: 5 additions & 14 deletions src/gateway/client/dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(feature = "framework")]
use std::sync::Arc;

use super::event_handler::InternalEventHandler;
use super::event_handler::{EventHandler, RawEventHandler};
use super::{Context, FullEvent};
#[cfg(feature = "cache")]
use crate::cache::{Cache, CacheUpdate};
Expand Down Expand Up @@ -48,18 +47,10 @@ pub(crate) async fn dispatch_model(
event: Event,
context: Context,
#[cfg(feature = "framework")] framework: Option<Arc<dyn Framework>>,
event_handler: Option<InternalEventHandler>,
event_handler: Option<Arc<dyn EventHandler>>,
raw_event_handler: Option<Arc<dyn RawEventHandler>>,
) {
let (handler, raw_handler) = match event_handler {
Some(InternalEventHandler::Normal(handler)) => (Some(handler), None),
Some(InternalEventHandler::Both {
raw,
normal,
}) => (Some(normal), Some(raw)),
Some(InternalEventHandler::Raw(raw_handler)) => (None, Some(raw_handler)),
None => (None, None),
};
if let Some(raw_handler) = raw_handler {
if let Some(raw_handler) = raw_event_handler {
raw_handler.raw_event(context.clone(), &event).await;
}

Expand All @@ -78,7 +69,7 @@ pub(crate) async fn dispatch_model(
framework.dispatch(&context, &full_event).await;
}

if let Some(handler) = handler {
if let Some(handler) = event_handler {
if let Some(extra_event) = extra_event {
extra_event.dispatch(context.clone(), &*handler).await;
}
Expand Down
8 changes: 0 additions & 8 deletions src/gateway/client/event_handler.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::VecDeque;
#[cfg(feature = "cache")]
use std::num::NonZeroU16;
use std::sync::Arc;

use async_trait::async_trait;
use strum::{EnumCount, IntoStaticStr, VariantNames};
Expand Down Expand Up @@ -535,10 +534,3 @@ pub trait RawEventHandler: Send + Sync {
true
}
}

#[derive(Clone)]
pub enum InternalEventHandler {
Raw(Arc<dyn RawEventHandler>),
Normal(Arc<dyn EventHandler>),
Both { raw: Arc<dyn RawEventHandler>, normal: Arc<dyn EventHandler> },
}
17 changes: 4 additions & 13 deletions src/gateway/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use futures::StreamExt as _;
use tracing::debug;

pub use self::context::Context;
pub use self::event_handler::{EventHandler, FullEvent, InternalEventHandler, RawEventHandler};
pub use self::event_handler::{EventHandler, FullEvent, RawEventHandler};
#[cfg(feature = "cache")]
use crate::cache::Cache;
#[cfg(feature = "cache")]
Expand Down Expand Up @@ -268,18 +268,8 @@ impl IntoFuture for ClientBuilder {
let presence = self.presence;
let http = self.http;

let event_handler = match (self.event_handler, self.raw_event_handler) {
(Some(normal), Some(raw)) => Some(InternalEventHandler::Both {
normal,
raw,
}),
(Some(h), None) => Some(InternalEventHandler::Normal(h)),
(None, Some(h)) => Some(InternalEventHandler::Raw(h)),
(None, None) => None,
};

if let Some(ratelimiter) = &http.ratelimiter {
if let Some(InternalEventHandler::Normal(event_handler)) = &event_handler {
if let Some(event_handler) = &self.event_handler {
let event_handler = Arc::clone(event_handler);
ratelimiter.set_ratelimit_callback(Box::new(move |info| {
let event_handler = Arc::clone(&event_handler);
Expand Down Expand Up @@ -313,7 +303,8 @@ impl IntoFuture for ClientBuilder {
let framework_cell = Arc::new(OnceLock::new());
let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions {
data: Arc::clone(&data),
event_handler,
event_handler: self.event_handler,
raw_event_handler: self.raw_event_handler,
#[cfg(feature = "framework")]
framework: Arc::clone(&framework_cell),
#[cfg(feature = "voice")]
Expand Down
13 changes: 8 additions & 5 deletions src/gateway/sharding/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use super::{ShardId, ShardQueue, ShardQueuer, ShardQueuerMessage, ShardRunnerInf
use crate::cache::Cache;
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::client::InternalEventHandler;
use crate::gateway::client::{EventHandler, RawEventHandler};
#[cfg(feature = "voice")]
use crate::gateway::VoiceGatewayManager;
use crate::gateway::{ConnectionStage, GatewayError, PresenceData};
Expand Down Expand Up @@ -49,7 +49,7 @@ use crate::model::gateway::GatewayIntents;
/// use std::env;
/// use std::sync::{Arc, OnceLock};
///
/// use serenity::gateway::client::{EventHandler, InternalEventHandler, RawEventHandler};
/// use serenity::gateway::client::EventHandler;
/// use serenity::gateway::{ShardManager, ShardManagerOptions};
/// use serenity::http::Http;
/// use serenity::model::gateway::GatewayIntents;
Expand All @@ -66,12 +66,13 @@ use crate::model::gateway::GatewayIntents;
/// let data = Arc::new(());
/// let shard_total = gateway_info.shards;
/// let ws_url = Arc::from(gateway_info.url);
/// let event_handler = Arc::new(Handler) as Arc<dyn EventHandler>;
/// let event_handler = Arc::new(Handler);
/// let max_concurrency = std::num::NonZeroU16::MIN;
///
/// ShardManager::new(ShardManagerOptions {
/// data,
/// event_handler: Some(InternalEventHandler::Normal(event_handler)),
/// event_handler: Some(event_handler),
/// raw_event_handler: None,
/// framework: Arc::new(OnceLock::new()),
/// # #[cfg(feature = "voice")]
/// # voice_manager: None,
Expand Down Expand Up @@ -128,6 +129,7 @@ impl ShardManager {
let mut shard_queuer = ShardQueuer {
data: opt.data,
event_handler: opt.event_handler,
raw_event_handler: opt.raw_event_handler,
#[cfg(feature = "framework")]
framework: opt.framework,
last_start: None,
Expand Down Expand Up @@ -356,7 +358,8 @@ impl Drop for ShardManager {

pub struct ShardManagerOptions {
pub data: Arc<dyn std::any::Any + Send + Sync>,
pub event_handler: Option<InternalEventHandler>,
pub event_handler: Option<Arc<dyn EventHandler>>,
pub raw_event_handler: Option<Arc<dyn RawEventHandler>>,
#[cfg(feature = "framework")]
pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
#[cfg(feature = "voice")]
Expand Down
12 changes: 6 additions & 6 deletions src/gateway/sharding/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use super::{
use crate::cache::Cache;
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::client::InternalEventHandler;
use crate::gateway::client::{EventHandler, RawEventHandler};
#[cfg(feature = "voice")]
use crate::gateway::VoiceGatewayManager;
use crate::gateway::{ConnectionStage, PresenceData, Shard, ShardRunnerMessage};
Expand All @@ -42,11 +42,10 @@ pub struct ShardQueuer {
///
/// [`Client::data`]: crate::Client::data
pub data: Arc<dyn std::any::Any + Send + Sync>,
/// A reference to [`EventHandler`] or [`RawEventHandler`].
///
/// [`EventHandler`]: crate::gateway::client::EventHandler
/// [`RawEventHandler`]: crate::gateway::client::RawEventHandler
pub event_handler: Option<InternalEventHandler>,
/// A reference to an [`EventHandler`].
pub event_handler: Option<Arc<dyn EventHandler>>,
/// A reference to a [`RawEventHandler`].
pub raw_event_handler: Option<Arc<dyn RawEventHandler>>,
/// A copy of the framework
#[cfg(feature = "framework")]
pub framework: Arc<OnceLock<Arc<dyn Framework>>>,
Expand Down Expand Up @@ -223,6 +222,7 @@ impl ShardQueuer {
let mut runner = ShardRunner::new(ShardRunnerOptions {
data: Arc::clone(&self.data),
event_handler: self.event_handler.clone(),
raw_event_handler: self.raw_event_handler.clone(),
#[cfg(feature = "framework")]
framework: self.framework.get().cloned(),
manager: Arc::clone(&self.manager),
Expand Down
35 changes: 16 additions & 19 deletions src/gateway/sharding/shard_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::cache::Cache;
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::client::dispatch::dispatch_model;
use crate::gateway::client::{Context, InternalEventHandler};
use crate::gateway::client::{Context, EventHandler, RawEventHandler};
#[cfg(feature = "voice")]
use crate::gateway::VoiceGatewayManager;
use crate::gateway::{ActivityData, ChunkGuildFilter, GatewayError};
Expand All @@ -30,7 +30,8 @@ use crate::model::user::OnlineStatus;
/// A runner for managing a [`Shard`] and its respective WebSocket client.
pub struct ShardRunner {
data: Arc<dyn std::any::Any + Send + Sync>,
event_handler: Option<InternalEventHandler>,
event_handler: Option<Arc<dyn EventHandler>>,
raw_event_handler: Option<Arc<dyn RawEventHandler>>,
#[cfg(feature = "framework")]
framework: Option<Arc<dyn Framework>>,
manager: Arc<ShardManager>,
Expand Down Expand Up @@ -58,6 +59,7 @@ impl ShardRunner {
runner_tx: tx,
data: opt.data,
event_handler: opt.event_handler,
raw_event_handler: opt.raw_event_handler,
#[cfg(feature = "framework")]
framework: opt.framework,
manager: opt.manager,
Expand Down Expand Up @@ -120,7 +122,7 @@ impl ShardRunner {
if post != pre {
self.update_manager().await;

if let Some(InternalEventHandler::Normal(event_handler)) = &self.event_handler {
if let Some(event_handler) = &self.event_handler {
let event_handler = Arc::clone(event_handler);
let context = self.make_context();
let event = ShardStageUpdateEvent {
Expand Down Expand Up @@ -173,21 +175,14 @@ impl ShardRunner {

if let Some(event) = event {
let context = self.make_context();
let can_dispatch = match &self.event_handler {
Some(InternalEventHandler::Normal(handler)) => {
handler.filter_event(&context, &event)
},
Some(InternalEventHandler::Raw(handler)) => {
handler.filter_event(&context, &event)
},
Some(InternalEventHandler::Both {
raw,
normal,
}) => {
raw.filter_event(&context, &event) && normal.filter_event(&context, &event)
},
None => true,
};
let can_dispatch = self
.event_handler
.as_ref()
.map_or(true, |handler| handler.filter_event(&context, &event))
&& self
.raw_event_handler
.as_ref()
.map_or(true, |handler| handler.filter_event(&context, &event));

if can_dispatch {
#[cfg(feature = "collector")]
Expand All @@ -214,6 +209,7 @@ impl ShardRunner {
#[cfg(feature = "framework")]
self.framework.clone(),
self.event_handler.clone(),
self.raw_event_handler.clone(),
),
);
}
Expand Down Expand Up @@ -509,7 +505,8 @@ impl ShardRunner {
/// Options to be passed to [`ShardRunner::new`].
pub struct ShardRunnerOptions {
pub data: Arc<dyn std::any::Any + Send + Sync>,
pub event_handler: Option<InternalEventHandler>,
pub event_handler: Option<Arc<dyn EventHandler>>,
pub raw_event_handler: Option<Arc<dyn RawEventHandler>>,
#[cfg(feature = "framework")]
pub framework: Option<Arc<dyn Framework>>,
pub manager: Arc<ShardManager>,
Expand Down

0 comments on commit c3d4a33

Please sign in to comment.