Skip to content

Commit 04a9598

Browse files
authoredMar 7, 2025··
feat(runner, bin): support test case partitioning (#257)
1 parent a1e6957 commit 04a9598

File tree

7 files changed

+151
-22
lines changed

7 files changed

+151
-22
lines changed
 

‎CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
## [0.28.0] - 2025-03-06
11+
12+
* runner: Add `Partitioner` and `Runner::with_partitioner` to enable partitioning of test cases, allowing only a subset of the glob result to be executed. This can be helpful for running tests in parallel in CI.
13+
* bin: Add `--partition-id` and `--partition-count` to set a hash partitioning for the test cases. If users are running in Buildkite CI with `parallelism: ..` specified in the workflow file, this will be automatically configured.
14+
1015
## [0.27.2] - 2025-02-18
1116

1217
* engines/bin: fix stdin to be closed properly to avoid hangs in the `external` engine.

‎Cargo.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ resolver = "2"
33
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]
44

55
[workspace.package]
6-
version = "0.27.2"
6+
version = "0.28.0"
77
edition = "2021"
88
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
99
keywords = ["sql", "database", "parser", "cli"]

‎sqllogictest-bin/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ glob = "0.3"
2424
itertools = "0.13"
2525
quick-junit = { version = "0.5" }
2626
rand = "0.8"
27-
sqllogictest = { path = "../sqllogictest", version = "0.27" }
28-
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.27" }
27+
sqllogictest = { path = "../sqllogictest", version = "0.28" }
28+
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.28" }
2929
tokio = { version = "1", features = [
3030
"rt",
3131
"rt-multi-thread",

‎sqllogictest-bin/src/main.rs

+100-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod engines;
22

33
use std::collections::{BTreeMap, BTreeSet, HashSet};
4+
use std::hash::{DefaultHasher, Hash, Hasher};
45
use std::io::{stdout, Read, Seek, SeekFrom, Write};
56
use std::path::{Path, PathBuf};
67
use std::time::{Duration, Instant};
@@ -19,7 +20,7 @@ use rand::seq::SliceRandom;
1920
use sqllogictest::substitution::well_known;
2021
use sqllogictest::{
2122
default_column_validator, default_normalizer, default_validator, update_record_with_output,
22-
AsyncDB, Injected, MakeConnection, Record, Runner,
23+
AsyncDB, Injected, MakeConnection, Partitioner, Record, Runner,
2324
};
2425
use tokio_util::task::AbortOnDropHandle;
2526

@@ -32,6 +33,10 @@ pub enum Color {
3233
Never,
3334
}
3435

36+
// Env keys for partitioning.
37+
const PARTITION_ID_ENV_KEY: &str = "SLT_PARTITION_ID";
38+
const PARTITION_COUNT_ENV_KEY: &str = "SLT_PARTITION_COUNT";
39+
3540
#[derive(Parser, Debug, Clone)]
3641
#[clap(about, version, author)]
3742
struct Opt {
@@ -112,6 +117,18 @@ struct Opt {
112117
/// The engine name is a label by default.
113118
#[clap(long = "label")]
114119
labels: Vec<String>,
120+
121+
/// Partition ID for sharding the test files. When used with `partition_count`,
122+
/// divides the test files into shards based on the hash of the file path.
123+
///
124+
/// Useful for running tests in parallel across multiple CI jobs. Currently
125+
/// automatically configured in Buildkite.
126+
#[clap(long, env = PARTITION_ID_ENV_KEY)]
127+
partition_id: Option<u64>,
128+
129+
/// Total number of partitions for test sharding. More details in `partition_id`.
130+
#[clap(long, env = PARTITION_COUNT_ENV_KEY)]
131+
partition_count: Option<u64>,
115132
}
116133

117134
/// Connection configuration.
@@ -138,10 +155,62 @@ impl DBConfig {
138155
}
139156
}
140157

158+
struct HashPartitioner {
159+
count: u64,
160+
id: u64,
161+
}
162+
163+
impl HashPartitioner {
164+
fn new(count: u64, id: u64) -> Result<Self> {
165+
if count == 0 {
166+
bail!("partition count must be greater than zero");
167+
}
168+
if id >= count {
169+
bail!("partition id (zero-based) must be less than count");
170+
}
171+
Ok(Self { count, id })
172+
}
173+
}
174+
175+
impl Partitioner for HashPartitioner {
176+
fn matches(&self, file_name: &str) -> bool {
177+
let mut hasher = DefaultHasher::new();
178+
file_name.hash(&mut hasher);
179+
hasher.finish() % self.count == self.id
180+
}
181+
}
182+
183+
#[allow(clippy::needless_return)]
184+
fn import_partition_config_from_ci() {
185+
if std::env::var_os(PARTITION_ID_ENV_KEY).is_some()
186+
|| std::env::var_os(PARTITION_COUNT_ENV_KEY).is_some()
187+
{
188+
// Ignore if already set.
189+
return;
190+
}
191+
192+
// Buildkite
193+
{
194+
const ID: &str = "BUILDKITE_PARALLEL_JOB";
195+
const COUNT: &str = "BUILDKITE_PARALLEL_JOB_COUNT";
196+
197+
if let (Some(id), Some(count)) = (std::env::var_os(ID), std::env::var_os(COUNT)) {
198+
std::env::set_var(PARTITION_ID_ENV_KEY, id);
199+
std::env::set_var(PARTITION_COUNT_ENV_KEY, count);
200+
eprintln!("Imported partition config from Buildkite.");
201+
return;
202+
}
203+
}
204+
205+
// TODO: more CI providers
206+
}
207+
141208
#[tokio::main]
142209
pub async fn main() -> Result<()> {
143210
tracing_subscriber::fmt::init();
144211

212+
import_partition_config_from_ci();
213+
145214
let cli = Opt::command().disable_help_flag(true).arg(
146215
Arg::new("help")
147216
.long("help")
@@ -167,6 +236,8 @@ pub async fn main() -> Result<()> {
167236
r#override,
168237
format,
169238
labels,
239+
partition_count,
240+
partition_id,
170241
} = Opt::from_arg_matches(&matches)
171242
.map_err(|err| err.exit())
172243
.unwrap();
@@ -205,17 +276,34 @@ pub async fn main() -> Result<()> {
205276
Color::Auto => {}
206277
}
207278

279+
let partitioner = if let Some(count) = partition_count {
280+
let id = partition_id.context("parallel job count is specified but job id is not")?;
281+
Some(HashPartitioner::new(count, id)?)
282+
} else {
283+
None
284+
};
285+
208286
let glob_patterns = files;
209-
let mut files: Vec<PathBuf> = Vec::new();
210-
for glob_pattern in glob_patterns.into_iter() {
211-
let pathbufs = glob::glob(&glob_pattern).context("failed to read glob pattern")?;
212-
for pathbuf in pathbufs.into_iter().try_collect::<_, Vec<_>, _>()? {
213-
files.push(pathbuf)
287+
let mut all_files = Vec::new();
288+
289+
for glob_pattern in glob_patterns {
290+
let mut files: Vec<PathBuf> = glob::glob(&glob_pattern)
291+
.context("failed to read glob pattern")?
292+
.try_collect()?;
293+
294+
// Test against partitioner only if there are multiple files matched, e.g., expanded from an `*`.
295+
if files.len() > 1 {
296+
if let Some(partitioner) = &partitioner {
297+
let len = files.len();
298+
files.retain(|path| partitioner.matches(path.to_str().unwrap()));
299+
let len_after = files.len();
300+
eprintln!(
301+
"Running {len_after} out of {len} test cases for glob pattern \"{glob_pattern}\" based on partitioning.",
302+
);
303+
}
214304
}
215-
}
216305

217-
if files.is_empty() {
218-
bail!("no test case found");
306+
all_files.extend(files);
219307
}
220308

221309
let config = DBConfig {
@@ -227,7 +315,7 @@ pub async fn main() -> Result<()> {
227315
};
228316

229317
if r#override || format {
230-
return update_test_files(files, &engine, config, format).await;
318+
return update_test_files(all_files, &engine, config, format).await;
231319
}
232320

233321
let mut report = Report::new(junit.clone().unwrap_or_else(|| "sqllogictest".to_string()));
@@ -241,7 +329,7 @@ pub async fn main() -> Result<()> {
241329
jobs,
242330
keep_db_on_failure,
243331
&mut test_suite,
244-
files,
332+
all_files,
245333
&engine,
246334
config,
247335
&labels,
@@ -252,7 +340,7 @@ pub async fn main() -> Result<()> {
252340
} else {
253341
run_serial(
254342
&mut test_suite,
255-
files,
343+
all_files,
256344
&engine,
257345
config,
258346
&labels,

‎sqllogictest-engines/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"] }
2020
rust_decimal = { version = "1.36.0", features = ["tokio-pg"] }
2121
serde = { version = "1", features = ["derive"] }
2222
serde_json = "1"
23-
sqllogictest = { path = "../sqllogictest", version = "0.27" }
23+
sqllogictest = { path = "../sqllogictest", version = "0.28" }
2424
thiserror = "2"
2525
tokio = { version = "1", features = [
2626
"rt",

‎sqllogictest/src/runner.rs

+39-3
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,27 @@ pub fn strict_column_validator<T: ColumnType>(actual: &Vec<T>, expected: &Vec<T>
522522
.any(|(actual_column, expected_column)| actual_column != expected_column)
523523
}
524524

525+
/// Decide whether a test file should be run. Useful for partitioning tests into multiple
526+
/// parallel machines to speed up test runs.
527+
pub trait Partitioner: Send + Sync + 'static {
528+
/// Returns true if the given file name matches the partition and should be run.
529+
fn matches(&self, file_name: &str) -> bool;
530+
}
531+
532+
impl<F> Partitioner for F
533+
where
534+
F: Fn(&str) -> bool + Send + Sync + 'static,
535+
{
536+
fn matches(&self, file_name: &str) -> bool {
537+
self(file_name)
538+
}
539+
}
540+
541+
/// The default partitioner matches all files.
542+
pub fn default_partitioner(_file_name: &str) -> bool {
543+
true
544+
}
545+
525546
#[derive(Default)]
526547
pub(crate) struct RunnerLocals {
527548
/// The temporary directory. Test cases can use `__TEST_DIR__` to refer to this directory.
@@ -560,6 +581,7 @@ pub struct Runner<D: AsyncDB, M: MakeConnection<Conn = D>> {
560581
// normalizer is used to normalize the result text
561582
normalizer: Normalizer,
562583
column_type_validator: ColumnTypeValidator<D::ColumnType>,
584+
partitioner: Arc<dyn Partitioner>,
563585
substitution_on: bool,
564586
sort_mode: Option<SortMode>,
565587
result_mode: Option<ResultMode>,
@@ -580,6 +602,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
580602
validator: default_validator,
581603
normalizer: default_normalizer,
582604
column_type_validator: default_column_validator,
605+
partitioner: Arc::new(default_partitioner),
583606
substitution_on: false,
584607
sort_mode: None,
585608
result_mode: None,
@@ -611,6 +634,13 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
611634
self.column_type_validator = validator;
612635
}
613636

637+
/// Set the partitioner for the runner. Only files that match the partitioner will be run.
638+
///
639+
/// This only takes effect when running tests in parallel.
640+
pub fn with_partitioner(&mut self, partitioner: impl Partitioner + 'static) {
641+
self.partitioner = Arc::new(partitioner);
642+
}
643+
614644
pub fn with_hash_threshold(&mut self, hash_threshold: usize) {
615645
self.hash_threshold = hash_threshold;
616646
}
@@ -1281,13 +1311,18 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
12811311
{
12821312
let files = glob::glob(glob).expect("failed to read glob pattern");
12831313
let mut tasks = vec![];
1284-
// let conn_builder = Arc::new(conn_builder);
12851314

12861315
for (idx, file) in files.enumerate() {
12871316
// for every slt file, we create a database against table conflict
12881317
let file = file.unwrap();
1289-
let db_name = file.to_str().expect("not a UTF-8 filename");
1290-
let db_name = db_name.replace([' ', '.', '-', '/'], "_");
1318+
let filename = file.to_str().expect("not a UTF-8 filename");
1319+
1320+
// Skip files that don't match the partitioner.
1321+
if !self.partitioner.matches(filename) {
1322+
continue;
1323+
}
1324+
1325+
let db_name = filename.replace([' ', '.', '-', '/'], "_");
12911326

12921327
self.conn
12931328
.run_default(&format!("CREATE DATABASE {db_name};"))
@@ -1305,6 +1340,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
13051340
validator: self.validator,
13061341
normalizer: self.normalizer,
13071342
column_type_validator: self.column_type_validator,
1343+
partitioner: self.partitioner.clone(),
13081344
substitution_on: self.substitution_on,
13091345
sort_mode: self.sort_mode,
13101346
result_mode: self.result_mode,

0 commit comments

Comments
 (0)
Please sign in to comment.