Skip to content

Commit 2003b1a

Browse files
authored
feat: support local variable substitution, add __DATABASE__ variable for bin (#253)
* support local variable substitution Signed-off-by: Richard Chien <[email protected]> * update changelog Signed-off-by: Richard Chien <[email protected]> --------- Signed-off-by: Richard Chien <[email protected]>
1 parent 89d0d3c commit 2003b1a

File tree

7 files changed

+118
-48
lines changed

7 files changed

+118
-48
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
* runner: Add `Runner::set_var` method to allow adding runner-local variables for substitution.
11+
* bin: Add `__DATABASE__` variable for accessing current database name from SLT files.
12+
1013
## [0.27.0] - 2025-02-11
1114

1215
* runner: add `shutdown` method to `DB` and `AsyncDB` trait to allow for graceful shutdown of the database connection. Users are encouraged to call `Runner::shutdown` or `Runner::shutdown_async` after running tests to ensure that the database connections are properly closed.

sqllogictest-bin/src/main.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use itertools::Itertools;
1616
use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
1717
use rand::distributions::DistString;
1818
use rand::seq::SliceRandom;
19+
use sqllogictest::substitution::well_known;
1920
use sqllogictest::{
2021
default_column_validator, default_normalizer, default_validator, update_record_with_output,
2122
AsyncDB, Injected, MakeConnection, Record, Runner,
@@ -466,6 +467,7 @@ async fn run_serial(
466467
for label in labels {
467468
runner.add_label(label);
468469
}
470+
runner.set_var(well_known::DATABASE.to_owned(), config.db.clone());
469471

470472
let filename = file.to_string_lossy().to_string();
471473
let test_case_name = filename.replace(['/', ' ', '.', '-'], "_");
@@ -539,6 +541,7 @@ async fn update_test_files(
539541
) -> Result<()> {
540542
for file in files {
541543
let mut runner = Runner::new(|| engines::connect(engine, &config));
544+
runner.set_var(well_known::DATABASE.to_owned(), config.db.clone());
542545

543546
if let Err(e) = update_test_file(&mut std::io::stdout(), &mut runner, &file, format).await {
544547
{
@@ -568,6 +571,7 @@ async fn connect_and_run_test_file(
568571
for label in labels {
569572
runner.add_label(label);
570573
}
574+
runner.set_var(well_known::DATABASE.to_owned(), config.db.clone());
571575
let result = run_test_file(out, &mut runner, filename).await;
572576
runner.shutdown_async().await;
573577

sqllogictest/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,9 @@ pub mod connection;
5757
pub mod harness;
5858
pub mod parser;
5959
pub mod runner;
60+
pub mod substitution;
6061

6162
pub use self::column_type::*;
6263
pub use self::connection::*;
6364
pub use self::parser::*;
6465
pub use self::runner::*;
65-
66-
mod substitution;

sqllogictest/src/runner.rs

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
//! Sqllogictest runner.
22
3-
use std::collections::HashSet;
3+
use std::collections::{BTreeMap, HashSet};
44
use std::fmt::{Debug, Display};
55
use std::path::Path;
66
use std::process::{Command, ExitStatus, Output};
7-
use std::sync::Arc;
7+
use std::sync::{Arc, OnceLock};
88
use std::time::Duration;
99
use std::vec;
1010

@@ -16,6 +16,7 @@ use md5::Digest;
1616
use owo_colors::OwoColorize;
1717
use rand::Rng;
1818
use similar::{Change, ChangeTag, TextDiff};
19+
use tempfile::TempDir;
1920

2021
use crate::parser::*;
2122
use crate::substitution::Substitution;
@@ -521,6 +522,36 @@ pub fn strict_column_validator<T: ColumnType>(actual: &Vec<T>, expected: &Vec<T>
521522
.any(|(actual_column, expected_column)| actual_column != expected_column)
522523
}
523524

525+
#[derive(Default)]
526+
pub(crate) struct RunnerLocals {
527+
/// The temporary directory. Test cases can use `__TEST_DIR__` to refer to this directory.
528+
/// Lazily initialized and cleaned up when dropped.
529+
test_dir: OnceLock<TempDir>,
530+
/// Runtime variables for substitution.
531+
variables: BTreeMap<String, String>,
532+
}
533+
534+
impl RunnerLocals {
535+
pub fn test_dir(&self) -> String {
536+
let test_dir = self
537+
.test_dir
538+
.get_or_init(|| TempDir::new().expect("failed to create testdir"));
539+
test_dir.path().to_string_lossy().into_owned()
540+
}
541+
542+
fn set_var(&mut self, key: String, value: String) {
543+
self.variables.insert(key, value);
544+
}
545+
546+
pub fn get_var(&self, key: &str) -> Option<&String> {
547+
self.variables.get(key)
548+
}
549+
550+
pub fn vars(&self) -> &BTreeMap<String, String> {
551+
&self.variables
552+
}
553+
}
554+
524555
/// Sqllogictest runner.
525556
pub struct Runner<D: AsyncDB, M: MakeConnection<Conn = D>> {
526557
conn: Connections<D, M>,
@@ -529,13 +560,15 @@ pub struct Runner<D: AsyncDB, M: MakeConnection<Conn = D>> {
529560
// normalizer is used to normalize the result text
530561
normalizer: Normalizer,
531562
column_type_validator: ColumnTypeValidator<D::ColumnType>,
532-
substitution: Option<Substitution>,
563+
substitution_on: bool,
533564
sort_mode: Option<SortMode>,
534565
result_mode: Option<ResultMode>,
535566
/// 0 means never hashing
536567
hash_threshold: usize,
537568
/// Labels for condition `skipif` and `onlyif`.
538569
labels: HashSet<String>,
570+
/// Local variables/context for the runner.
571+
locals: RunnerLocals,
539572
}
540573

541574
impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
@@ -547,12 +580,13 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
547580
validator: default_validator,
548581
normalizer: default_normalizer,
549582
column_type_validator: default_column_validator,
550-
substitution: None,
583+
substitution_on: false,
551584
sort_mode: None,
552585
result_mode: None,
553586
hash_threshold: 0,
554587
labels: HashSet::new(),
555588
conn: Connections::new(make_conn),
589+
locals: RunnerLocals::default(),
556590
}
557591
}
558592

@@ -561,6 +595,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
561595
self.labels.insert(label.to_string());
562596
}
563597

598+
/// Set a local variable for substitution.
599+
pub fn set_var(&mut self, key: String, value: String) {
600+
self.locals.set_var(key, value);
601+
}
602+
564603
pub fn with_normalizer(&mut self, normalizer: Normalizer) {
565604
self.normalizer = normalizer;
566605
}
@@ -862,11 +901,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
862901
Control::ResultMode(result_mode) => {
863902
self.result_mode = Some(result_mode);
864903
}
865-
Control::Substitution(on_off) => match (&mut self.substitution, on_off) {
866-
(s @ None, true) => *s = Some(Substitution::default()),
867-
(s @ Some(_), false) => *s = None,
868-
_ => {}
869-
},
904+
Control::Substitution(on_off) => self.substitution_on = on_off,
870905
}
871906

872907
RecordOutput::Nothing
@@ -1260,18 +1295,22 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
12601295
.expect("create db failed");
12611296
let target = hosts[idx % hosts.len()].clone();
12621297

1298+
let mut locals = RunnerLocals::default();
1299+
locals.set_var("__DATABASE__".to_owned(), db_name.clone());
1300+
12631301
let mut tester = Runner {
12641302
conn: Connections::new(move || {
12651303
conn_builder(target.clone(), db_name.clone()).map(Ok)
12661304
}),
12671305
validator: self.validator,
12681306
normalizer: self.normalizer,
12691307
column_type_validator: self.column_type_validator,
1270-
substitution: self.substitution.clone(),
1308+
substitution_on: self.substitution_on,
12711309
sort_mode: self.sort_mode,
12721310
result_mode: self.result_mode,
12731311
hash_threshold: self.hash_threshold,
12741312
labels: self.labels.clone(),
1313+
locals,
12751314
};
12761315

12771316
tasks.push(async move {
@@ -1317,9 +1356,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
13171356
/// This is useful for `system` commands: The shell can do the environment variables, and we can
13181357
/// write strings like `\n` without escaping.
13191358
fn may_substitute(&self, input: String, subst_env_vars: bool) -> Result<String, AnyError> {
1320-
if let Some(substitution) = &self.substitution {
1321-
substitution
1322-
.substitute(&input, subst_env_vars)
1359+
if self.substitution_on {
1360+
Substitution::new(&self.locals, subst_env_vars)
1361+
.substitute(&input)
13231362
.map_err(|e| Arc::new(e) as AnyError)
13241363
} else {
13251364
Ok(input)

sqllogictest/src/substitution.rs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,75 @@
1-
use std::sync::{Arc, OnceLock};
2-
31
use subst::Env;
4-
use tempfile::{tempdir, TempDir};
2+
3+
use crate::RunnerLocals;
4+
5+
pub mod well_known {
6+
pub const TEST_DIR: &str = "__TEST_DIR__";
7+
pub const NOW: &str = "__NOW__";
8+
pub const DATABASE: &str = "__DATABASE__";
9+
}
510

611
/// Substitute environment variables and special variables like `__TEST_DIR__` in SQL.
7-
#[derive(Default, Clone)]
8-
pub(crate) struct Substitution {
9-
/// The temporary directory for `__TEST_DIR__`.
10-
/// Lazily initialized and cleaned up when dropped.
11-
test_dir: Arc<OnceLock<TempDir>>,
12+
pub(crate) struct Substitution<'a> {
13+
runner_locals: &'a RunnerLocals,
14+
subst_env_vars: bool,
15+
}
16+
17+
impl Substitution<'_> {
18+
pub fn new(runner_locals: &RunnerLocals, subst_env_vars: bool) -> Substitution {
19+
Substitution {
20+
runner_locals,
21+
subst_env_vars,
22+
}
23+
}
1224
}
1325

1426
#[derive(thiserror::Error, Debug)]
1527
#[error("substitution failed: {0}")]
1628
pub(crate) struct SubstError(subst::Error);
1729

18-
impl Substitution {
19-
pub fn substitute(&self, input: &str, subst_env_vars: bool) -> Result<String, SubstError> {
20-
if !subst_env_vars {
21-
Ok(input
22-
.replace("$__TEST_DIR__", &self.test_dir())
23-
.replace("$__NOW__", &self.now()))
24-
} else {
30+
fn now_string() -> String {
31+
std::time::SystemTime::now()
32+
.duration_since(std::time::UNIX_EPOCH)
33+
.expect("failed to get current time")
34+
.as_nanos()
35+
.to_string()
36+
}
37+
38+
impl Substitution<'_> {
39+
pub fn substitute(&self, input: &str) -> Result<String, SubstError> {
40+
if self.subst_env_vars {
2541
subst::substitute(input, self).map_err(SubstError)
42+
} else {
43+
Ok(self.simple_replace(input))
2644
}
2745
}
2846

29-
fn test_dir(&self) -> String {
30-
let test_dir = self
31-
.test_dir
32-
.get_or_init(|| tempdir().expect("failed to create testdir"));
33-
test_dir.path().to_string_lossy().into_owned()
34-
}
35-
36-
fn now(&self) -> String {
37-
std::time::SystemTime::now()
38-
.duration_since(std::time::UNIX_EPOCH)
39-
.expect("failed to get current time")
40-
.as_nanos()
41-
.to_string()
47+
fn simple_replace(&self, input: &str) -> String {
48+
let mut res = input
49+
.replace(
50+
&format!("${}", well_known::TEST_DIR),
51+
&self.runner_locals.test_dir(),
52+
)
53+
.replace(&format!("${}", well_known::NOW), &now_string());
54+
for (key, value) in self.runner_locals.vars() {
55+
res = res.replace(&format!("${}", key), value);
56+
}
57+
res
4258
}
4359
}
4460

45-
impl<'a> subst::VariableMap<'a> for Substitution {
61+
impl<'a> subst::VariableMap<'a> for Substitution<'a> {
4662
type Value = String;
4763

4864
fn get(&'a self, key: &str) -> Option<Self::Value> {
4965
match key {
50-
"__TEST_DIR__" => self.test_dir().into(),
51-
"__NOW__" => self.now().into(),
52-
key => Env.get(key),
66+
well_known::TEST_DIR => self.runner_locals.test_dir().into(),
67+
well_known::NOW => now_string().into(),
68+
key => self
69+
.runner_locals
70+
.get_var(key)
71+
.cloned()
72+
.or_else(|| Env.get(key)),
5373
}
5474
}
5575
}

tests/substitution/basic.slt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ path $__TEST_DIR__
2929
statement ok
3030
time $__NOW__
3131

32+
# a local variable set before running tester
33+
statement ok
34+
check $__DATABASE__
35+
3236
# non existent variables without default values are errors
3337
statement error No such variable
3438
check $NONEXISTENT_VARIABLE

tests/substitution/substitution.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rusty_fork::rusty_fork_test;
2-
use sqllogictest::{DBOutput, DefaultColumnType};
2+
use sqllogictest::{substitution::well_known, DBOutput, DefaultColumnType};
33

44
pub struct FakeDB;
55

@@ -59,6 +59,7 @@ rusty_fork_test! {
5959
std::env::set_var("MY_PASSWORD", "rust");
6060

6161
let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) });
62+
tester.set_var(well_known::DATABASE.to_owned(), "fake_db".to_owned());
6263

6364
tester.run_file("./substitution/basic.slt").unwrap();
6465
}

0 commit comments

Comments
 (0)