Skip to content

Commit d8a8f7f

Browse files
authored
feat: update batching logic to check for builtin weights when closing agg batch (#828)
1 parent 87a80fa commit d8a8f7f

File tree

7 files changed

+276
-12
lines changed

7 files changed

+276
-12
lines changed

orchestrator/src/core/config.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use crate::{
4242
OrchestratorError, OrchestratorResult,
4343
};
4444

45+
use crate::types::batch::AggregatorBatchWeights;
4546
use blockifier::bouncer::BouncerWeights;
4647

4748
/// Starknet versions supported by the service
@@ -119,6 +120,7 @@ pub struct ConfigParam {
119120
/// * Aggregator Proof
120121
pub store_audit_artifacts: bool,
121122
pub bouncer_weights_limit: BouncerWeights,
123+
pub aggregator_batch_weights_limit: AggregatorBatchWeights,
122124
}
123125

124126
/// The app config. It can be accessed from anywhere inside the service
@@ -203,6 +205,8 @@ impl Config {
203205
let settlement_config = SettlementConfig::try_from(run_cmd.clone())
204206
.context("Failed to create settlement config from run command")?;
205207

208+
let bouncer_weights_limit = Self::load_bouncer_weights_limit(&run_cmd.bouncer_weights_limit_file)?;
209+
206210
let layer = run_cmd.layer.clone();
207211

208212
let params = ConfigParam {
@@ -221,7 +225,8 @@ impl Config {
221225
prover_layout_name: Self::get_layout_name(run_cmd.proving_layout_args.prover_layout_name.clone().as_str())
222226
.context("Failed to get prover layout name")?,
223227
store_audit_artifacts: run_cmd.store_audit_artifacts,
224-
bouncer_weights_limit: Self::load_bouncer_weights_limit(&run_cmd.bouncer_weights_limit_file)?,
228+
aggregator_batch_weights_limit: AggregatorBatchWeights::from(&bouncer_weights_limit),
229+
bouncer_weights_limit,
225230
};
226231
let rpc_client = JsonRpcClient::new(HttpTransport::new(params.madara_rpc_url.clone()));
227232
let feeder_gateway_client = RestClient::new(params.madara_feeder_gateway_url.clone());

orchestrator/src/tests/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::core::config::{Config, ConfigParam, StarknetVersion};
1212
use crate::core::{DatabaseClient, QueueClient, StorageClient};
1313
use crate::server::{get_server_url, setup_server};
1414
use crate::tests::common::{create_queues, create_sns_arn, drop_database};
15+
use crate::types::batch::AggregatorBatchWeights;
1516
use crate::types::constant::BLOB_LEN;
1617
use crate::types::params::batching::BatchingParams;
1718
use crate::types::params::cloud_provider::AWSCredentials;
@@ -26,6 +27,7 @@ use crate::types::Layer;
2627
use crate::utils::rest_client::RestClient;
2728
use alloy::primitives::Address;
2829
use axum::Router;
30+
use blockifier::bouncer::BouncerWeights;
2931
use cairo_vm::types::layout_name::LayoutName;
3032
use generate_pie::constants::{DEFAULT_SEPOLIA_ETH_FEE_TOKEN, DEFAULT_SEPOLIA_STRK_FEE_TOKEN};
3133
use httpmock::MockServer;
@@ -740,6 +742,7 @@ pub(crate) fn get_env_params() -> EnvParams {
740742
.parse::<bool>()
741743
.unwrap_or(false),
742744
bouncer_weights_limit: Default::default(), // Use default bouncer weights for tests
745+
aggregator_batch_weights_limit: AggregatorBatchWeights::from(&BouncerWeights::default()),
743746
};
744747

745748
let instrumentation_params = OTELConfig {

orchestrator/src/tests/jobs/state_update_job/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::error::job::state_update::StateUpdateError;
77
use crate::error::job::JobError;
88
use crate::tests::common::default_job_item;
99
use crate::tests::config::{ConfigType, TestConfigBuilder};
10-
use crate::types::batch::AggregatorBatch;
10+
use crate::types::batch::{AggregatorBatch, AggregatorBatchWeights};
1111
use crate::types::constant::{
1212
BLOB_DATA_FILE_NAME, PROGRAM_OUTPUT_FILE_NAME, SNOS_OUTPUT_FILE_NAME, STORAGE_ARTIFACTS_DIR, STORAGE_BLOB_DIR,
1313
};
@@ -87,6 +87,7 @@ async fn test_process_job_works(
8787
String::from(""),
8888
String::from(""),
8989
String::from(""),
90+
AggregatorBatchWeights::default(),
9091
"0.13.2".to_string(),
9192
)])
9293
});
@@ -253,6 +254,7 @@ async fn process_job_works_unit_test() {
253254
String::from(""),
254255
String::from(""),
255256
String::from(""),
257+
AggregatorBatchWeights::default(),
256258
"0.13.2".to_string(),
257259
)])
258260
});

orchestrator/src/tests/utils.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::types::batch::{AggregatorBatch, AggregatorBatchStatus};
1+
use crate::types::batch::{AggregatorBatch, AggregatorBatchStatus, AggregatorBatchWeights};
22
use chrono::{SubsecRound, Utc};
33
use rstest::fixture;
44
use uuid::Uuid;
@@ -104,6 +104,7 @@ pub fn build_batch(
104104
updated_at: Utc::now().round_subsecs(0),
105105
bucket_id: String::from("ABCD1234"),
106106
status: AggregatorBatchStatus::Open,
107+
builtin_weights: AggregatorBatchWeights::default(),
107108
starknet_version: "0.13.2".to_string(),
108109
}
109110
}

orchestrator/src/tests/workers/batching/mod.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ async fn test_batching_worker(#[case] has_existing_batch: bool) -> Result<(), Bo
2929
let mut database = MockDatabaseClient::new();
3030
let mut storage = MockStorageClient::new();
3131
let mut lock = MockLockClient::new();
32+
33+
let provider_url = format!("http://localhost:{}", server.port());
34+
3235
let start_block;
3336
let end_block;
3437

@@ -143,14 +146,22 @@ async fn test_batching_worker(#[case] has_existing_batch: bool) -> Result<(), Bo
143146

144147
let mut prover_client = MockProverClient::new();
145148
if !has_existing_batch {
146-
prover_client.expect_submit_task().times(1).returning(|_| Ok("bucket_id".to_string()));
149+
prover_client.expect_submit_task().times(2).returning(|_| Ok("bucket_id".to_string()));
147150
}
148151

152+
// Mock builtin weights calls for each block
153+
let builtin_weights = get_dummy_builtin_weights();
154+
server.mock(|when, then| {
155+
when.path("/feeder_gateway/get_block_bouncer_weights");
156+
then.status(200).body(serde_json::to_vec(&json!({"bouncer_weights": builtin_weights})).unwrap());
157+
});
158+
149159
let services = TestConfigBuilder::new()
150160
.configure_starknet_client(provider.into())
151-
.configure_database(database.into())
152-
.configure_storage_client(storage.into())
161+
.configure_madara_feeder_gateway_url(&provider_url)
153162
.configure_prover_client(prover_client.into())
163+
.configure_storage_client(storage.into())
164+
.configure_database(database.into())
154165
.configure_lock_client(lock.into())
155166
.build()
156167
.await;
@@ -171,6 +182,8 @@ async fn test_batching_worker_with_multiple_blocks() -> Result<(), Box<dyn Error
171182
let mut storage = MockStorageClient::new();
172183
let mut lock = MockLockClient::new();
173184

185+
let provider_url = format!("http://localhost:{}", server.port());
186+
174187
let existing_aggregator_batch = crate::types::batch::AggregatorBatch {
175188
index: 1,
176189
start_block: 0,
@@ -306,8 +319,16 @@ async fn test_batching_worker_with_multiple_blocks() -> Result<(), Box<dyn Error
306319
let mut prover_client = MockProverClient::new();
307320
prover_client.expect_submit_task().times(2).returning(|_| Ok("new_bucket_id".to_string()));
308321

322+
// Mock builtin weights calls for each block
323+
let builtin_weights = get_dummy_builtin_weights();
324+
server.mock(|when, then| {
325+
when.path("/feeder_gateway/get_block_bouncer_weights");
326+
then.status(200).body(serde_json::to_vec(&json!({"bouncer_weights": builtin_weights})).unwrap());
327+
});
328+
309329
let services = TestConfigBuilder::new()
310330
.configure_starknet_client(provider.into())
331+
.configure_madara_feeder_gateway_url(&provider_url)
311332
.configure_database(database.into())
312333
.configure_storage_client(storage.into())
313334
.configure_prover_client(prover_client.into())

orchestrator/src/types/batch.rs

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use blockifier::bouncer::BouncerWeights;
12
use chrono::{DateTime, SubsecRound, Utc};
23
#[cfg(feature = "with_mongodb")]
34
use mongodb::bson::serde_helpers::{chrono_datetime_as_bson_datetime, uuid_1_as_binary};
@@ -66,6 +67,12 @@ pub struct SnosBatchUpdates {
6667
pub status: Option<SnosBatchStatus>,
6768
}
6869

70+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
71+
pub struct AggregatorBatchWeights {
72+
pub l1_gas: usize,
73+
pub message_segment_length: usize,
74+
}
75+
6976
/// Aggregator Batch
7077
///
7178
/// Represents a high-level batch that contains multiple SNOS batches and manages
@@ -139,13 +146,42 @@ pub struct AggregatorBatch {
139146
/// Used to track the batch in the proving system
140147
pub bucket_id: String,
141148

149+
/// Builtin weights for the batch. We decide when to close a batch based on this.
150+
pub builtin_weights: AggregatorBatchWeights,
151+
142152
/// Current status of the aggregator batch
143153
pub status: AggregatorBatchStatus,
144154
/// Starknet protocol version for all blocks in this batch
145155
/// All blocks in a batch must have the same Starknet version for prover compatibility
146156
pub starknet_version: String,
147157
}
148158

159+
impl AggregatorBatchWeights {
160+
pub fn new(l1_gas: usize, message_segment_length: usize) -> Self {
161+
Self { l1_gas, message_segment_length }
162+
}
163+
164+
pub fn checked_add(&self, other: &AggregatorBatchWeights) -> Option<AggregatorBatchWeights> {
165+
Some(Self {
166+
l1_gas: self.l1_gas.checked_add(other.l1_gas)?,
167+
message_segment_length: self.message_segment_length.checked_add(other.message_segment_length)?,
168+
})
169+
}
170+
171+
pub fn checked_sub(&self, other: &AggregatorBatchWeights) -> Option<AggregatorBatchWeights> {
172+
Some(Self {
173+
l1_gas: self.l1_gas.checked_sub(other.l1_gas)?,
174+
message_segment_length: self.message_segment_length.checked_sub(other.message_segment_length)?,
175+
})
176+
}
177+
}
178+
179+
impl From<&BouncerWeights> for AggregatorBatchWeights {
180+
fn from(weights: &BouncerWeights) -> Self {
181+
Self { l1_gas: weights.l1_gas, message_segment_length: weights.message_segment_length }
182+
}
183+
}
184+
149185
impl AggregatorBatch {
150186
/// Creates a new aggregator batch
151187
///
@@ -158,13 +194,15 @@ impl AggregatorBatch {
158194
///
159195
/// # Returns
160196
/// A new `AggregatorBatch` instance with status `Open` and single block
197+
#[allow(clippy::too_many_arguments)]
161198
pub fn new(
162199
index: u64,
163200
start_snos_batch: u64,
164201
start_block: u64,
165202
squashed_state_updates_path: String,
166203
blob_path: String,
167204
bucket_id: String,
205+
builtin_weights: AggregatorBatchWeights,
168206
starknet_version: String,
169207
) -> Self {
170208
Self {
@@ -184,6 +222,7 @@ impl AggregatorBatch {
184222
bucket_id,
185223
starknet_version,
186224
status: AggregatorBatchStatus::Open,
225+
builtin_weights,
187226
}
188227
}
189228
}
@@ -303,3 +342,143 @@ impl SnosBatch {
303342
Ok(())
304343
}
305344
}
345+
346+
#[cfg(test)]
347+
mod tests {
348+
use super::*;
349+
350+
mod aggregator_batch_weights_tests {
351+
use super::*;
352+
353+
#[test]
354+
fn test_new() {
355+
let weights = AggregatorBatchWeights::new(1000, 500);
356+
assert_eq!(weights.l1_gas, 1000);
357+
assert_eq!(weights.message_segment_length, 500);
358+
}
359+
360+
#[test]
361+
fn test_checked_add_success() {
362+
let weights1 = AggregatorBatchWeights::new(1000, 500);
363+
let weights2 = AggregatorBatchWeights::new(2000, 300);
364+
365+
let result = weights1.checked_add(&weights2);
366+
assert!(result.is_some());
367+
368+
let sum = result.unwrap();
369+
assert_eq!(sum.l1_gas, 3000);
370+
assert_eq!(sum.message_segment_length, 800);
371+
}
372+
373+
#[test]
374+
fn test_checked_add_overflow_l1_gas() {
375+
let weights1 = AggregatorBatchWeights::new(usize::MAX, 100);
376+
let weights2 = AggregatorBatchWeights::new(1, 100);
377+
378+
let result = weights1.checked_add(&weights2);
379+
assert!(result.is_none());
380+
}
381+
382+
#[test]
383+
fn test_checked_add_overflow_message_segment_length() {
384+
let weights1 = AggregatorBatchWeights::new(100, usize::MAX);
385+
let weights2 = AggregatorBatchWeights::new(100, 1);
386+
387+
let result = weights1.checked_add(&weights2);
388+
assert!(result.is_none());
389+
}
390+
391+
#[test]
392+
fn test_checked_add_max_values() {
393+
let weights1 = AggregatorBatchWeights::new(usize::MAX / 2, usize::MAX / 2);
394+
let weights2 = AggregatorBatchWeights::new(usize::MAX / 2, usize::MAX / 2);
395+
396+
let result = weights1.checked_add(&weights2);
397+
assert!(result.is_some());
398+
399+
let sum = result.unwrap();
400+
assert_eq!(sum.l1_gas, usize::MAX - 1);
401+
assert_eq!(sum.message_segment_length, usize::MAX - 1);
402+
}
403+
404+
#[test]
405+
fn test_checked_sub_success() {
406+
let weights1 = AggregatorBatchWeights::new(2000, 500);
407+
let weights2 = AggregatorBatchWeights::new(1000, 300);
408+
409+
let result = weights1.checked_sub(&weights2);
410+
assert!(result.is_some());
411+
412+
let diff = result.unwrap();
413+
assert_eq!(diff.l1_gas, 1000);
414+
assert_eq!(diff.message_segment_length, 200);
415+
}
416+
417+
#[test]
418+
fn test_checked_sub_with_zero() {
419+
let weights1 = AggregatorBatchWeights::new(1000, 500);
420+
let weights2 = AggregatorBatchWeights::new(0, 0);
421+
422+
let result = weights1.checked_sub(&weights2);
423+
assert!(result.is_some());
424+
425+
let diff = result.unwrap();
426+
assert_eq!(diff.l1_gas, 1000);
427+
assert_eq!(diff.message_segment_length, 500);
428+
}
429+
430+
#[test]
431+
fn test_checked_sub_equal_values() {
432+
let weights1 = AggregatorBatchWeights::new(1000, 500);
433+
let weights2 = AggregatorBatchWeights::new(1000, 500);
434+
435+
let result = weights1.checked_sub(&weights2);
436+
assert!(result.is_some());
437+
438+
let diff = result.unwrap();
439+
assert_eq!(diff.l1_gas, 0);
440+
assert_eq!(diff.message_segment_length, 0);
441+
}
442+
443+
#[test]
444+
fn test_checked_sub_underflow_l1_gas() {
445+
let weights1 = AggregatorBatchWeights::new(100, 500);
446+
let weights2 = AggregatorBatchWeights::new(200, 300);
447+
448+
let result = weights1.checked_sub(&weights2);
449+
assert!(result.is_none());
450+
}
451+
452+
#[test]
453+
fn test_checked_sub_underflow_message_segment_length() {
454+
let weights1 = AggregatorBatchWeights::new(500, 100);
455+
let weights2 = AggregatorBatchWeights::new(300, 200);
456+
457+
let result = weights1.checked_sub(&weights2);
458+
assert!(result.is_none());
459+
}
460+
461+
#[test]
462+
fn test_checked_sub_from_max() {
463+
let weights1 = AggregatorBatchWeights::new(usize::MAX, usize::MAX);
464+
let weights2 = AggregatorBatchWeights::new(1, 1);
465+
466+
let result = weights1.checked_sub(&weights2);
467+
assert!(result.is_some());
468+
469+
let diff = result.unwrap();
470+
assert_eq!(diff.l1_gas, usize::MAX - 1);
471+
assert_eq!(diff.message_segment_length, usize::MAX - 1);
472+
}
473+
474+
#[test]
475+
fn test_from_bouncer_weights() {
476+
let bouncer_weights =
477+
BouncerWeights { l1_gas: 1234, message_segment_length: usize::MAX, ..Default::default() };
478+
479+
let agg_weights = AggregatorBatchWeights::from(&bouncer_weights);
480+
assert_eq!(agg_weights.l1_gas, 1234);
481+
assert_eq!(agg_weights.message_segment_length, usize::MAX);
482+
}
483+
}
484+
}

0 commit comments

Comments
 (0)