Skip to content

Commit 89d0d3c

Browse files
authored
feat: add shutdown method to AsyncDB DB and Runner (#250)
* add shutdown interface & use Signed-off-by: Bugen Zhao <[email protected]> * fix postgres impl Signed-off-by: Bugen Zhao <[email protected]> * manual shutdown Signed-off-by: Bugen Zhao <[email protected]> * remove f word Signed-off-by: Bugen Zhao <[email protected]> * also add shutdown method to sync db Signed-off-by: Bugen Zhao <[email protected]> * bump version and add change log Signed-off-by: Bugen Zhao <[email protected]> --------- Signed-off-by: Bugen Zhao <[email protected]>
1 parent c3b8c52 commit 89d0d3c

File tree

15 files changed

+95
-33
lines changed

15 files changed

+95
-33
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
## [0.27.0] - 2025-02-11
11+
12+
* runner: add `shutdown` method to `DB` and `AsyncDB` trait to allow for graceful shutdown of the database connection. Users are encouraged to call `Runner::shutdown` or `Runner::shutdown_async` after running tests to ensure that the database connections are properly closed.
13+
1014
## [0.26.4] - 2025-01-27
1115

1216
* runner: add random string in path generation to avoid conflict when using `include`.

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ resolver = "2"
33
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]
44

55
[workspace.package]
6-
version = "0.26.4"
6+
version = "0.27.0"
77
edition = "2021"
88
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
99
keywords = ["sql", "database", "parser", "cli"]

sqllogictest-bin/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ glob = "0.3"
2323
itertools = "0.13"
2424
quick-junit = { version = "0.5" }
2525
rand = "0.8"
26-
sqllogictest = { path = "../sqllogictest", version = "0.26" }
27-
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.26" }
26+
sqllogictest = { path = "../sqllogictest", version = "0.27" }
27+
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.27" }
2828
tokio = { version = "1", features = [
2929
"rt",
3030
"rt-multi-thread",

sqllogictest-bin/src/engines.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,8 @@ impl AsyncDB for Engines {
154154
async fn run_command(command: std::process::Command) -> std::io::Result<std::process::Output> {
155155
Command::from(command).output().await
156156
}
157+
158+
async fn shutdown(&mut self) {
159+
dispatch_engines!(self, e, { e.shutdown().await })
160+
}
157161
}

sqllogictest-bin/src/main.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ async fn run_parallel(
437437
}
438438
}
439439

440+
// Shutdown the connection for managing temporary databases.
441+
db.shutdown().await;
442+
440443
if !failed_case.is_empty() {
441444
Err(anyhow!("some test case failed:\n{:#?}", failed_case))
442445
} else {
@@ -467,7 +470,7 @@ async fn run_serial(
467470
let filename = file.to_string_lossy().to_string();
468471
let test_case_name = filename.replace(['/', ' ', '.', '-'], "_");
469472
let mut failed = false;
470-
let case = match run_test_file(&mut std::io::stdout(), runner, &file).await {
473+
let case = match run_test_file(&mut std::io::stdout(), &mut runner, &file).await {
471474
Ok(duration) => {
472475
let mut case = TestCase::new(test_case_name, TestCaseStatus::success());
473476
case.set_time(duration);
@@ -495,6 +498,7 @@ async fn run_serial(
495498
case
496499
}
497500
};
501+
runner.shutdown_async().await;
498502
test_suite.add_test_case(case);
499503
if connection_refused {
500504
eprintln!("Connection refused. The server may be down. Exiting...");
@@ -534,14 +538,16 @@ async fn update_test_files(
534538
format: bool,
535539
) -> Result<()> {
536540
for file in files {
537-
let runner = Runner::new(|| engines::connect(engine, &config));
541+
let mut runner = Runner::new(|| engines::connect(engine, &config));
538542

539-
if let Err(e) = update_test_file(&mut std::io::stdout(), runner, &file, format).await {
543+
if let Err(e) = update_test_file(&mut std::io::stdout(), &mut runner, &file, format).await {
540544
{
541545
println!("{}\n\n{:?}", style("[FAILED]").red().bold(), e);
542546
println!();
543547
}
544548
};
549+
550+
runner.shutdown_async().await;
545551
}
546552

547553
Ok(())
@@ -562,16 +568,17 @@ async fn connect_and_run_test_file(
562568
for label in labels {
563569
runner.add_label(label);
564570
}
565-
let result = run_test_file(out, runner, filename).await?;
571+
let result = run_test_file(out, &mut runner, filename).await;
572+
runner.shutdown_async().await;
566573

567-
Ok(result)
574+
result
568575
}
569576

570577
/// Different from [`Runner::run_file_async`], we re-implement it here to print some progress
571578
/// information.
572579
async fn run_test_file<T: std::io::Write, M: MakeConnection>(
573580
out: &mut T,
574-
mut runner: Runner<M::Conn, M>,
581+
runner: &mut Runner<M::Conn, M>,
575582
filename: impl AsRef<Path>,
576583
) -> Result<Duration> {
577584
let filename = filename.as_ref();
@@ -676,7 +683,7 @@ fn finish_test_file<T: std::io::Write>(
676683
/// progress information.
677684
async fn update_test_file<T: std::io::Write, M: MakeConnection>(
678685
out: &mut T,
679-
mut runner: Runner<M::Conn, M>,
686+
runner: &mut Runner<M::Conn, M>,
680687
filename: impl AsRef<Path>,
681688
format: bool,
682689
) -> Result<()> {
@@ -804,7 +811,7 @@ async fn update_test_file<T: std::io::Write, M: MakeConnection>(
804811
writeln!(outfile, "{record}")?;
805812
continue;
806813
}
807-
update_record(outfile, &mut runner, record, format)
814+
update_record(outfile, runner, record, format)
808815
.await
809816
.context(format!("failed to run `{}`", style(filename).bold()))?;
810817
}

sqllogictest-engines/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"] }
2020
rust_decimal = { version = "1.36.0", features = ["tokio-pg"] }
2121
serde = { version = "1", features = ["derive"] }
2222
serde_json = "1"
23-
sqllogictest = { path = "../sqllogictest", version = "0.26" }
23+
sqllogictest = { path = "../sqllogictest", version = "0.27" }
2424
thiserror = "2"
2525
tokio = { version = "1", features = [
2626
"rt",

sqllogictest-engines/src/external.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ impl AsyncDB for ExternalDriver {
113113
}
114114
}
115115

116+
async fn shutdown(&mut self) {
117+
self.stdin.shutdown().await.ok();
118+
self.child.wait().await.ok();
119+
}
120+
116121
fn engine_name(&self) -> &str {
117122
"external"
118123
}

sqllogictest-engines/src/mysql.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ impl sqllogictest::AsyncDB for MySql {
6666
}
6767
}
6868

69+
async fn shutdown(&mut self) {
70+
self.pool.clone().disconnect().await.ok();
71+
}
72+
6973
fn engine_name(&self) -> &str {
7074
"mysql"
7175
}

sqllogictest-engines/src/postgres.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ mod extended;
22
mod simple;
33

44
use std::marker::PhantomData;
5-
use std::sync::Arc;
65

76
use tokio::task::JoinHandle;
87

@@ -16,8 +15,8 @@ pub struct Extended;
1615
/// Generic Postgres engine based on the client from [`tokio_postgres`]. The protocol `P` can be
1716
/// either [`Simple`] or [`Extended`].
1817
pub struct Postgres<P> {
19-
client: Arc<tokio_postgres::Client>,
20-
join_handle: JoinHandle<()>,
18+
/// `None` means the connection is closed.
19+
conn: Option<(tokio_postgres::Client, JoinHandle<()>)>,
2120
_protocol: PhantomData<P>,
2221
}
2322

@@ -34,27 +33,28 @@ impl<P> Postgres<P> {
3433
pub async fn connect(config: PostgresConfig) -> Result<Self> {
3534
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
3635

37-
let join_handle = tokio::spawn(async move {
36+
let connection = tokio::spawn(async move {
3837
if let Err(e) = connection.await {
3938
log::error!("Postgres connection error: {:?}", e);
4039
}
4140
});
4241

4342
Ok(Self {
44-
client: Arc::new(client),
45-
join_handle,
43+
conn: Some((client, connection)),
4644
_protocol: PhantomData,
4745
})
4846
}
4947

5048
/// Returns a reference of the inner Postgres client.
51-
pub fn pg_client(&self) -> &tokio_postgres::Client {
52-
&self.client
49+
pub fn client(&self) -> &tokio_postgres::Client {
50+
&self.conn.as_ref().expect("connection is shutdown").0
5351
}
54-
}
5552

56-
impl<P> Drop for Postgres<P> {
57-
fn drop(&mut self) {
58-
self.join_handle.abort()
53+
/// Shutdown the Postgres connection.
54+
async fn shutdown(&mut self) {
55+
if let Some((client, connection)) = self.conn.take() {
56+
drop(client);
57+
connection.await.ok();
58+
}
5959
}
6060
}

sqllogictest-engines/src/postgres/extended.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ macro_rules! array_process {
7777
match v {
7878
Some(v) => {
7979
let sql = format!("select ($1::{})::varchar", stringify!($ty_name));
80-
let tmp_rows = $self.client.query(&sql, &[&v]).await.unwrap();
80+
let tmp_rows = $self.client().query(&sql, &[&v]).await.unwrap();
8181
let value: &str = tmp_rows.get(0).unwrap().get(0);
8282
assert!(value.len() > 0);
8383
write!(output, "{}", value).unwrap();
@@ -128,7 +128,7 @@ macro_rules! single_process {
128128
match value {
129129
Some(value) => {
130130
let sql = format!("select ($1::{})::varchar", stringify!($ty_name));
131-
let tmp_rows = $self.client.query(&sql, &[&value]).await.unwrap();
131+
let tmp_rows = $self.client().query(&sql, &[&value]).await.unwrap();
132132
let value: &str = tmp_rows.get(0).unwrap().get(0);
133133
assert!(value.len() > 0);
134134
$row_vec.push(value.to_string());
@@ -188,9 +188,9 @@ impl sqllogictest::AsyncDB for Postgres<Extended> {
188188
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>> {
189189
let mut output = vec![];
190190

191-
let stmt = self.client.prepare(sql).await?;
191+
let stmt = self.client().prepare(sql).await?;
192192
let rows = self
193-
.client
193+
.client()
194194
.query_raw(&stmt, std::iter::empty::<&(dyn ToSql + Sync)>())
195195
.await?;
196196

@@ -311,6 +311,10 @@ impl sqllogictest::AsyncDB for Postgres<Extended> {
311311
}
312312
}
313313

314+
async fn shutdown(&mut self) {
315+
self.shutdown().await;
316+
}
317+
314318
fn engine_name(&self) -> &str {
315319
"postgres-extended"
316320
}

sqllogictest-engines/src/postgres/simple.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ impl sqllogictest::AsyncDB for Postgres<Simple> {
2020
// and we have to follow the format given by the specific database (pg).
2121
// For example, postgres will output `t` as true and `f` as false,
2222
// thus we have to write `t`/`f` in the expected results.
23-
let rows = self.client.simple_query(sql).await?;
23+
let rows = self.client().simple_query(sql).await?;
2424
let mut cnt = 0;
2525
for row in rows {
2626
let mut row_vec = vec![];
@@ -62,6 +62,10 @@ impl sqllogictest::AsyncDB for Postgres<Simple> {
6262
}
6363
}
6464

65+
async fn shutdown(&mut self) {
66+
self.shutdown().await;
67+
}
68+
6569
fn engine_name(&self) -> &str {
6670
"postgres"
6771
}

sqllogictest/src/connection.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::collections::HashMap;
22
use std::future::IntoFuture;
33

4+
use futures::future::join_all;
45
use futures::Future;
56

67
use crate::{AsyncDB, Connection as ConnectionName, DBOutput};
@@ -68,4 +69,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Connections<D, M> {
6869
pub async fn run_default(&mut self, sql: &str) -> Result<DBOutput<D::ColumnType>, D::Error> {
6970
self.get(ConnectionName::Default).await?.run(sql).await
7071
}
72+
73+
/// Shutdown all connections.
74+
pub async fn shutdown_all(&mut self) {
75+
join_all(self.conns.values_mut().map(|conn| conn.shutdown())).await;
76+
}
7177
}

sqllogictest/src/harness.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,6 @@ macro_rules! harness {
3535
pub fn test(filename: impl AsRef<Path>, make_conn: impl MakeConnection) -> Result<(), Failed> {
3636
let mut tester = Runner::new(make_conn);
3737
tester.run_file(filename)?;
38+
tester.shutdown();
3839
Ok(())
3940
}

sqllogictest/src/runner.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ pub trait AsyncDB {
7272
/// Async run a SQL query and return the output.
7373
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error>;
7474

75+
/// Shutdown the connection gracefully.
76+
async fn shutdown(&mut self);
77+
7578
/// Engine name of current database.
7679
fn engine_name(&self) -> &str {
7780
""
@@ -106,6 +109,9 @@ pub trait DB {
106109
/// Run a SQL query and return the output.
107110
fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error>;
108111

112+
/// Shutdown the connection gracefully.
113+
fn shutdown(&mut self) {}
114+
109115
/// Engine name of current database.
110116
fn engine_name(&self) -> &str {
111117
""
@@ -125,6 +131,10 @@ where
125131
D::run(self, sql)
126132
}
127133

134+
async fn shutdown(&mut self) {
135+
D::shutdown(self);
136+
}
137+
128138
fn engine_name(&self) -> &str {
129139
D::engine_name(self)
130140
}
@@ -512,7 +522,7 @@ pub fn strict_column_validator<T: ColumnType>(actual: &Vec<T>, expected: &Vec<T>
512522
}
513523

514524
/// Sqllogictest runner.
515-
pub struct Runner<D: AsyncDB, M: MakeConnection> {
525+
pub struct Runner<D: AsyncDB, M: MakeConnection<Conn = D>> {
516526
conn: Connections<D, M>,
517527
// validator is used for validate if the result of query equals to expected.
518528
validator: Validator,
@@ -1472,6 +1482,19 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
14721482
}
14731483
}
14741484

1485+
impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
1486+
/// Shutdown all connections in the runner.
1487+
pub async fn shutdown_async(&mut self) {
1488+
tracing::debug!("shutting down runner...");
1489+
self.conn.shutdown_all().await;
1490+
}
1491+
1492+
/// Shutdown all connections in the runner.
1493+
pub fn shutdown(&mut self) {
1494+
block_on(self.shutdown_async());
1495+
}
1496+
}
1497+
14751498
/// Updates the specified [`Record`] with the [`QueryOutput`] produced
14761499
/// by a Database, returning `Some(new_record)`.
14771500
///

0 commit comments

Comments
 (0)