fix(CallBatchLayer): don't batch if single request #2397

Open · wants to merge 13 commits into main
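
Editor's note: a minimal usage sketch of the layer this PR changes, to make the fixed behavior concrete. With `CallBatchLayer` installed, calls issued concurrently within the same wait window are folded into one Multicall3 `aggregate3` request; after this PR, a window that ends up containing only a single request is sent directly instead of being wrapped in a multicall. The builder method names (`layer`, `connect_http`), `CallBatchLayer::new()`, the RPC URL, and the tokio driver below are illustrative assumptions, not taken from this diff.

```rust
// Sketch only: constructor names, the URL, and the tokio runtime are assumptions.
use alloy_provider::{layers::CallBatchLayer, Provider, ProviderBuilder};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Attach the batching layer; requests arriving within its wait window are
    // aggregated into a single Multicall3 `aggregate3` eth_call.
    let provider = ProviderBuilder::new()
        .layer(CallBatchLayer::new())
        .connect_http("http://localhost:8545".parse()?);

    // Issued concurrently: both are candidates for the same batch.
    let (block, chain_id) =
        tokio::join!(provider.get_block_number(), provider.get_chain_id());
    println!("block = {}, chain id = {}", block?, chain_id?);

    // Issued on its own: after this PR the layer skips Multicall3 for a
    // single pending request and performs the call directly.
    let who = "0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045".parse()?;
    let balance = provider.get_balance(who).await?;
    println!("balance = {balance}");
    Ok(())
}
```
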
157 changes: 104 additions & 53 deletions crates/provider/src/layers/batch.rs
@@ -123,15 +123,29 @@ where

type CallBatchMsgTx = TransportResult<IMulticall3::Result>;

struct CallBatchMsg {
call: IMulticall3::Call3,
struct CallBatchMsg<N: Network> {
kind: CallBatchMsgKind<N>,
tx: oneshot::Sender<CallBatchMsgTx>,
}

impl fmt::Debug for CallBatchMsg {
impl<N: Network> Clone for CallBatchMsgKind<N>
where
N::TransactionRequest: Clone,
{
fn clone(&self) -> Self {
match self {
Self::Call(tx) => Self::Call(tx.clone()),
Self::BlockNumber => Self::BlockNumber,
Self::ChainId => Self::ChainId,
Self::Balance(addr) => Self::Balance(*addr),
}
}
}

impl<N: Network> fmt::Debug for CallBatchMsg<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("BatchProviderMessage(")?;
self.call.fmt(f)?;
self.kind.fmt(f)?;
f.write_str(")")
}
}
@@ -144,18 +158,15 @@ enum CallBatchMsgKind<N: Network = Ethereum> {
Balance(Address),
}

impl CallBatchMsg {
fn new<N: Network>(
kind: CallBatchMsgKind<N>,
m3a: Address,
) -> (Self, oneshot::Receiver<CallBatchMsgTx>) {
impl<N: Network> CallBatchMsg<N> {
fn new(kind: CallBatchMsgKind<N>) -> (Self, oneshot::Receiver<CallBatchMsgTx>) {
let (tx, rx) = oneshot::channel();
(Self { call: kind.into_call3(m3a), tx }, rx)
(Self { kind, tx }, rx)
}
}

impl<N: Network> CallBatchMsgKind<N> {
fn into_call3(self, m3a: Address) -> IMulticall3::Call3 {
fn to_call3(&self, m3a: Address) -> IMulticall3::Call3 {
let m3a_call = |data: Vec<u8>| IMulticall3::Call3 {
target: m3a,
allowFailure: true,
@@ -169,7 +180,9 @@ impl<N: Network> CallBatchMsgKind<N> {
},
Self::BlockNumber => m3a_call(IMulticall3::getBlockNumberCall {}.abi_encode()),
Self::ChainId => m3a_call(IMulticall3::getChainIdCall {}.abi_encode()),
Self::Balance(addr) => m3a_call(IMulticall3::getEthBalanceCall { addr }.abi_encode()),
Self::Balance(addr) => {
m3a_call(IMulticall3::getEthBalanceCall { addr: *addr }.abi_encode())
}
}
}
}
@@ -179,7 +192,7 @@ impl<N: Network> CallBatchMsgKind<N> {
/// See [`CallBatchLayer`] for more information.
pub struct CallBatchProvider<P, N: Network = Ethereum> {
provider: Arc<P>,
inner: CallBatchProviderInner,
inner: CallBatchProviderInner<N>,
_pd: PhantomData<N>,
}

@@ -209,20 +222,21 @@ impl<P: Provider<N> + 'static, N: Network> CallBatchProvider<P, N> {
}
}

#[allow(dead_code)]
#[derive(Clone)]
struct CallBatchProviderInner {
tx: mpsc::UnboundedSender<CallBatchMsg>,
struct CallBatchProviderInner<N: Network> {
tx: mpsc::UnboundedSender<CallBatchMsg<N>>,
m3a: Address,
}

impl CallBatchProviderInner {
impl<N: Network> CallBatchProviderInner<N> {
/// We only want to perform a scheduled multicall if:
/// - The request has no block ID or state overrides,
/// - The request has a target address,
/// - The request has no other properties (`nonce`, `gas`, etc cannot be sent with a multicall).
///
/// Ref: <https://github.com/wevm/viem/blob/ba8319f71503af8033fd3c77cfb64c7eb235c6a9/src/actions/public/call.ts#L295>
fn should_batch_call<N: Network>(&self, params: &crate::EthCallParams<N>) -> bool {
fn should_batch_call(&self, params: &crate::EthCallParams<N>) -> bool {
// TODO: block ID is not yet implemented
if params.block().is_some_and(|block| block != BlockId::latest()) {
return false;
@@ -242,25 +256,28 @@ impl CallBatchProviderInner {
true
}
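
Editor's note, not part of this diff: a quick sketch of what the criteria above boil down to in practice, using the `TransactionBuilder` helpers this file already relies on (`with_to`, `with_input`); `with_gas_limit`, the address, and the calldata are illustrative assumptions.

```rust
// Sketch only: the address, calldata, and gas value are placeholders.
use alloy_network::TransactionBuilder;
use alloy_primitives::{address, bytes};
use alloy_rpc_types_eth::TransactionRequest;

fn batching_examples() -> (TransactionRequest, TransactionRequest) {
    // Only `to` + `input`, implicitly at the latest block: `should_batch_call`
    // accepts this, so it can be folded into a Multicall3 `aggregate3` call.
    let batchable = TransactionRequest::default()
        .with_to(address!("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045"))
        .with_input(bytes!("06fdde03"));

    // A gas limit (like a non-latest block ID or state overrides) cannot be
    // carried through a multicall, so this request is sent as a plain eth_call.
    let not_batchable = TransactionRequest::default()
        .with_to(address!("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045"))
        .with_gas_limit(100_000);

    (batchable, not_batchable)
}
```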

async fn schedule<N: Network>(self, msg: CallBatchMsgKind<N>) -> TransportResult<Bytes> {
let (msg, rx) = CallBatchMsg::new(msg, self.m3a);
async fn schedule(self, msg: CallBatchMsgKind<N>) -> TransportResult<Bytes> {
let (msg, rx) = CallBatchMsg::new(msg);
self.tx.send(msg).map_err(|_| TransportErrorKind::backend_gone())?;

let IMulticall3::Result { success, returnData: data } =
let IMulticall3::Result { success, returnData } =
rx.await.map_err(|_| TransportErrorKind::backend_gone())??;

if !success {
let revert_data = if data.is_empty() { "" } else { &format!(" with data: {data}") };
return Err(TransportErrorKind::custom_str(&format!(
let revert_data = if returnData.is_empty() {
"".to_string()
} else {
format!(" with data: {returnData}")
};
Err(TransportErrorKind::custom_str(&format!(
"multicall batched call reverted{revert_data}"
)));
)))
} else {
Ok(returnData)
}
Ok(data)
}

async fn schedule_and_decode<N: Network, T>(
self,
msg: CallBatchMsgKind<N>,
) -> TransportResult<T>
async fn schedule_and_decode<T>(self, msg: CallBatchMsgKind<N>) -> TransportResult<T>
where
T: SolValue + From<<T::SolType as SolType>::RustType>,
{
@@ -273,13 +290,13 @@ struct CallBatchBackend<P, N: Network = Ethereum> {
inner: Arc<P>,
m3a: Address,
wait: Duration,
rx: mpsc::UnboundedReceiver<CallBatchMsg>,
pending: Vec<CallBatchMsg>,
rx: mpsc::UnboundedReceiver<CallBatchMsg<N>>,
pending: Vec<CallBatchMsg<N>>,
_pd: PhantomData<N>,
}

impl<P: Provider<N> + 'static, N: Network> CallBatchBackend<P, N> {
fn spawn(inner: Arc<P>, layer: &CallBatchLayer) -> mpsc::UnboundedSender<CallBatchMsg> {
fn spawn(inner: Arc<P>, layer: &CallBatchLayer) -> mpsc::UnboundedSender<CallBatchMsg<N>> {
let CallBatchLayer { m3a, wait } = *layer;
let (tx, rx) = mpsc::unbounded_channel();
let this = Self { inner, m3a, wait, rx, pending: Vec::new(), _pd: PhantomData };
@@ -311,13 +328,47 @@ impl<P: Provider<N> + 'static, N: Network> CallBatchBackend<P, N> {
}
}

fn process_msg(&mut self, msg: CallBatchMsg) {
fn process_msg(&mut self, msg: CallBatchMsg<N>) {
self.pending.push(msg);
}

async fn send_batch(&mut self) {
let result = self.send_batch_inner().await;
let pending = std::mem::take(&mut self.pending);

// If there's only a single pending call, skip the Multicall3 batch entirely:
// execute the request directly and wrap the result in a Multicall3-style response.
if pending.len() == 1 {
let msg = pending.into_iter().next().unwrap();

let result: TransportResult<IMulticall3::Result> = match msg.kind {
CallBatchMsgKind::Call(tx) => self
.inner
.call(tx)
.await
.map(|res| IMulticall3::Result { success: true, returnData: res }),
CallBatchMsgKind::BlockNumber => {
self.inner.get_block_number().into_future().await.map(|res| {
IMulticall3::Result { success: true, returnData: res.abi_encode().into() }
})
}
CallBatchMsgKind::ChainId => {
self.inner.get_chain_id().into_future().await.map(|res| IMulticall3::Result {
success: true,
returnData: res.abi_encode().into(),
})
}
CallBatchMsgKind::Balance(addr) => {
self.inner.get_balance(addr).into_future().await.map(|res| {
IMulticall3::Result { success: true, returnData: res.abi_encode().into() }
})
}
Review comment from Diamond (Contributor) on lines +343 to +364:

The error handling in the single-call path differs from the batched path. For consistency, errors from direct calls should be wrapped in an IMulticall3::Result with success=false rather than being passed through directly. This would ensure that callers experience the same error behavior regardless of whether their request was batched or executed individually.

Consider modifying each match arm to handle errors like:

CallBatchMsgKind::Call(tx) => self
    .inner
    .call(tx)
    .await
    .map(|res| IMulticall3::Result { success: true, returnData: res })
    .or_else(|err| Ok(IMulticall3::Result { 
        success: false, 
        returnData: err.to_string().into_bytes().into() 
    })),

This approach would maintain consistent behavior between batched and non-batched execution paths.

Suggested change
let result: TransportResult<IMulticall3::Result> = match msg.kind {
CallBatchMsgKind::Call(tx) => self
.inner
.call(tx)
.await
.map(|res| IMulticall3::Result { success: true, returnData: res }),
CallBatchMsgKind::BlockNumber => {
self.inner.get_block_number().into_future().await.map(|res| {
IMulticall3::Result { success: true, returnData: res.abi_encode().into() }
})
}
CallBatchMsgKind::ChainId => {
self.inner.get_chain_id().into_future().await.map(|res| IMulticall3::Result {
success: true,
returnData: res.abi_encode().into(),
})
}
CallBatchMsgKind::Balance(addr) => {
self.inner.get_balance(addr).into_future().await.map(|res| {
IMulticall3::Result { success: true, returnData: res.abi_encode().into() }
})
}
let result: TransportResult<IMulticall3::Result> = match msg.kind {
CallBatchMsgKind::Call(tx) => self
.inner
.call(tx)
.await
.map(|res| IMulticall3::Result { success: true, returnData: res })
.or_else(|err| Ok(IMulticall3::Result {
success: false,
returnData: err.to_string().into_bytes().into(),
})),
CallBatchMsgKind::BlockNumber => {
self.inner.get_block_number().into_future().await.map(|res| {
IMulticall3::Result { success: true, returnData: res.abi_encode().into() }
}).or_else(|err| {
Ok(IMulticall3::Result {
success: false,
returnData: err.to_string().into_bytes().into(),
})
})
}
CallBatchMsgKind::ChainId => {
self.inner.get_chain_id().into_future().await.map(|res| IMulticall3::Result {
success: true,
returnData: res.abi_encode().into(),
}).or_else(|err| {
Ok(IMulticall3::Result {
success: false,
returnData: err.to_string().into_bytes().into(),
})
})
}
CallBatchMsgKind::Balance(addr) => {
self.inner.get_balance(addr).into_future().await.map(|res| {
IMulticall3::Result { success: true, returnData: res.abi_encode().into() }
}).or_else(|err| {
Ok(IMulticall3::Result {
success: false,
returnData: err.to_string().into_bytes().into(),
})
})
}

};

let _ = msg.tx.send(result);
return;
}

let result = self.send_batch_inner(&pending).await;
match result {
Ok(results) => {
debug_assert_eq!(results.len(), pending.len());
@@ -333,28 +384,28 @@ impl<P: Provider<N> + 'static, N: Network> CallBatchBackend<P, N> {
}
}

async fn send_batch_inner(&mut self) -> TransportResult<Vec<IMulticall3::Result>> {
debug_assert!(!self.pending.is_empty());
debug!(len = self.pending.len(), "sending multicall");
let tx = N::TransactionRequest::default().with_to(self.m3a).with_input(self.make_payload());
async fn send_batch_inner(
&self,
pending: &[CallBatchMsg<N>],
) -> TransportResult<Vec<IMulticall3::Result>> {
let call3s: Vec<_> = pending.iter().map(|msg| msg.kind.to_call3(self.m3a)).collect();

let tx = N::TransactionRequest::default()
.with_to(self.m3a)
.with_input(IMulticall3::aggregate3Call { calls: call3s }.abi_encode());

let bytes = self.inner.call(tx).await?;
if bytes.is_empty() {
return Err(TransportErrorKind::custom_str(&format!(
"Multicall3 not deployed at {}",
self.m3a
)));
}

let ret = IMulticall3::aggregate3Call::abi_decode_returns(&bytes)
.map_err(TransportErrorKind::custom)?;
Ok(ret)
}

fn make_payload(&self) -> Vec<u8> {
IMulticall3::aggregate3Call {
calls: self.pending.iter().map(|msg| msg.call.clone()).collect(),
}
.abi_encode()
}
}

impl<P: Provider<N> + 'static, N: Network> Provider<N> for CallBatchProvider<P, N> {
@@ -374,7 +425,7 @@ impl<P: Provider<N> + 'static, N: Network> Provider<N> for CallBatchProvider<P,
alloy_primitives::BlockNumber,
> {
crate::ProviderCall::BoxedFuture(Box::pin(
self.inner.clone().schedule_and_decode::<N, u64>(CallBatchMsgKind::BlockNumber),
self.inner.clone().schedule_and_decode::<u64>(CallBatchMsgKind::BlockNumber),
))
}

@@ -386,7 +437,7 @@ impl<P: Provider<N> + 'static, N: Network> Provider<N> for CallBatchProvider<P,
alloy_primitives::ChainId,
> {
crate::ProviderCall::BoxedFuture(Box::pin(
self.inner.clone().schedule_and_decode::<N, u64>(CallBatchMsgKind::ChainId),
self.inner.clone().schedule_and_decode::<u64>(CallBatchMsgKind::ChainId),
))
}

@@ -399,25 +450,25 @@ impl<P: Provider<N> + 'static, N: Network> Provider<N> for CallBatchProvider<P,
ProviderCall::BoxedFuture(Box::pin(
this.inner
.clone()
.schedule_and_decode::<N, U256>(CallBatchMsgKind::Balance(address)),
.schedule_and_decode::<U256>(CallBatchMsgKind::Balance(address)),
))
}
})
}
}

struct CallBatchCaller {
inner: CallBatchProviderInner,
struct CallBatchCaller<N: Network> {
inner: CallBatchProviderInner<N>,
weak: WeakClient,
}

impl CallBatchCaller {
fn new<P: Provider<N>, N: Network>(provider: &CallBatchProvider<P, N>) -> Self {
impl<N: Network> CallBatchCaller<N> {
fn new<P: Provider<N>>(provider: &CallBatchProvider<P, N>) -> Self {
Self { inner: provider.inner.clone(), weak: provider.provider.weak_client() }
}
}

impl<N: Network> Caller<N, Bytes> for CallBatchCaller {
impl<N: Network> Caller<N, Bytes> for CallBatchCaller<N> {
fn call(
&self,
params: crate::EthCallParams<N>,
@@ -427,7 +478,7 @@ impl<N: Network> Caller<N, Bytes> for CallBatchCaller {
}

Ok(crate::ProviderCall::BoxedFuture(Box::pin(
self.inner.clone().schedule::<N>(CallBatchMsgKind::Call(params.into_data())),
self.inner.clone().schedule(CallBatchMsgKind::Call(params.into_data())),
)))
}
