Skip to content

Commit 184f27f

Browse files
committed
fix/core: use CAS to protect against concurrent TX state changes
Concurrent use a of a connection across multiple threads is not supported, but we should not crash even if a client attempts to do so. We do crash when this happens. See: #3911 The proper behavior is to return Busy if the same connection is already doing transactional work on another thread. To ensure the above, this PR implements Connection::atomic_swap_tx_state() which uses CAS instead of simply setting the state. This PR also adds a regression test modeled after the reproduction in #3911 that only accepts `"database is locked"` errors and panics on anything else.
1 parent caa71ea commit 184f27f

File tree

6 files changed

+145
-11
lines changed

6 files changed

+145
-11
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

bindings/rust/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ thiserror = { workspace = true }
2222
tracing-subscriber.workspace = true
2323
tracing.workspace = true
2424
mimalloc = { workspace = true, optional = true }
25+
rusqlite.workspace = true
2526

2627
[dev-dependencies]
2728
tempfile = { workspace = true }

bindings/rust/tests/integration_tests.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::time::Duration;
2+
13
use tokio::fs;
24
use turso::{Builder, EncryptionOpts, Error, Value};
35

@@ -531,3 +533,74 @@ async fn test_connection_clone() {
531533
let id: i64 = row.get(0).unwrap();
532534
assert_eq!(id, 1);
533535
}
536+
537+
use tokio::task::JoinSet;
538+
539+
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
540+
/// Sharing a connection across threads is not supported, but we must not crash.
541+
/// The proper behavior is to return Busy to the caller if another statement is already running a transaction.
542+
/// This test verifies that we end up inserting at least some of the rows and that integrity check passes.
543+
async fn concurrent_inserts_on_shared_connection() {
544+
let temp_file = tempfile::NamedTempFile::new().expect("create temp file");
545+
let db_path = temp_file.path().to_str().expect("path to string");
546+
547+
let db = Builder::new_local(db_path)
548+
.build()
549+
.await
550+
.expect("temp file db");
551+
let conn = db.connect().expect("connect");
552+
conn.execute("CREATE TABLE IF NOT EXISTS t (value INTEGER)", ())
553+
.await
554+
.expect("create table");
555+
556+
conn.busy_timeout(Duration::from_millis(1000)).unwrap();
557+
558+
let attempts = 200usize;
559+
let mut join_set = JoinSet::new();
560+
561+
for i in 0..attempts {
562+
let conn = conn.clone();
563+
join_set.spawn(async move {
564+
let res = conn
565+
.execute("INSERT INTO t (value) VALUES (?1)", [i as i64])
566+
.await;
567+
match res {
568+
Ok(_) => {}
569+
Err(Error::SqlExecutionFailure(e)) if e.contains("database is locked") => {}
570+
Err(e) => panic!("unexpected error: {e:?}"),
571+
}
572+
});
573+
}
574+
575+
while let Some(res) = join_set.join_next().await {
576+
res.expect("task panicked");
577+
}
578+
579+
let mut rows = conn
580+
.query("SELECT COUNT(*) FROM t", ())
581+
.await
582+
.expect("count rows");
583+
let count_value = rows
584+
.next()
585+
.await
586+
.expect("step result")
587+
.expect("row")
588+
.get_value(0)
589+
.expect("count value");
590+
match count_value {
591+
Value::Integer(count) => assert!(count > 0),
592+
other => panic!("expected integer count, got {other:?}"),
593+
}
594+
assert!(rows.next().await.expect("consume rows").is_none());
595+
596+
// Close turso connection before opening with rusqlite
597+
drop(conn);
598+
drop(db);
599+
600+
// Verify database integrity with rusqlite
601+
let rusqlite_conn = rusqlite::Connection::open(db_path).expect("open with rusqlite");
602+
let integrity_result: String = rusqlite_conn
603+
.query_row("PRAGMA integrity_check", [], |row| row.get(0))
604+
.expect("integrity check");
605+
assert_eq!(integrity_result, "ok", "database integrity check failed");
606+
}

core/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2623,6 +2623,16 @@ impl Connection {
26232623
self.transaction_state.get()
26242624
}
26252625

2626+
/// Atomically compare and exchange transaction state.
2627+
/// Returns Ok(current) on success, or Err(actual_value) if the state was modified concurrently.
2628+
fn atomic_swap_tx_state(
2629+
&self,
2630+
current: TransactionState,
2631+
new: TransactionState,
2632+
) -> Result<TransactionState, TransactionState> {
2633+
self.transaction_state.compare_exchange(current, new)
2634+
}
2635+
26262636
pub(crate) fn get_mv_tx_id(&self) -> Option<u64> {
26272637
self.mv_tx.read().map(|(tx_id, _)| tx_id)
26282638
}

core/vdbe/execute.rs

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,6 +2234,22 @@ pub fn op_transaction_inner(
22342234

22352235
// 1. We try to upgrade current version
22362236
let current_state = conn.get_tx_state();
2237+
let auto_commit = conn.auto_commit.load(Ordering::SeqCst);
2238+
2239+
// In autocommit mode, each statement is its own transaction. If we see
2240+
// a transaction already in progress (state != None), AND we haven't started
2241+
// any transaction work in this execution yet, it means another concurrent
2242+
// execution owns that transaction. Return Busy to avoid using their transaction.
2243+
//
2244+
// We check auto_txn_cleanup to distinguish:
2245+
// - RollbackTxn: We started a transaction, so state != None is expected (e.g., Read -> Write upgrade)
2246+
// - None: We haven't started anything, so state != None means someone else has the transaction
2247+
let connection_already_has_tx = !matches!(current_state, TransactionState::None)
2248+
&& matches!(state.auto_txn_cleanup, TxnCleanup::None);
2249+
if auto_commit && !conn.is_nested_stmt() && connection_already_has_tx {
2250+
return Err(LimboError::Busy);
2251+
}
2252+
22372253
let (new_transaction_state, updated) = if conn.is_nested_stmt() {
22382254
(current_state, false)
22392255
} else {
@@ -2275,7 +2291,22 @@ pub fn op_transaction_inner(
22752291
}
22762292
};
22772293

2278-
// 2. Start transaction if needed
2294+
// 2. Atomically claim the new transaction state BEFORE doing pager operations.
2295+
// This prevents races where two threads both think they can upgrade the state.
2296+
// If compare_exchange fails, another thread modified the state concurrently -
2297+
// return Busy to signal the caller should retry. We can't just retry here because
2298+
// pager transaction state is per-connection, and another thread's transaction
2299+
// is not ours to use.
2300+
if updated {
2301+
if conn
2302+
.atomic_swap_tx_state(current_state, new_transaction_state)
2303+
.is_err()
2304+
{
2305+
return Err(LimboError::Busy);
2306+
}
2307+
}
2308+
2309+
// 3. Start transaction on the pager. If this fails, restore the old state.
22792310
if let Some(mv_store) = mv_store.as_ref() {
22802311
// In MVCC we don't have write exclusivity, therefore we just need to start a transaction if needed.
22812312
// Programs can run Transaction twice, first with read flag and then with write flag. So a single txid is enough
@@ -2286,6 +2317,10 @@ pub fn op_transaction_inner(
22862317
let conn_has_executed_begin_deferred = !has_existing_mv_tx
22872318
&& !program.connection.auto_commit.load(Ordering::SeqCst);
22882319
if conn_has_executed_begin_deferred && *tx_mode == TransactionMode::Concurrent {
2320+
// Restore state before returning error
2321+
if updated {
2322+
conn.set_tx_state(current_state);
2323+
}
22892324
return Err(LimboError::TxError(
22902325
"Cannot start CONCURRENT transaction after BEGIN DEFERRED".to_string(),
22912326
));
@@ -2318,6 +2353,10 @@ pub fn op_transaction_inner(
23182353
}
23192354
} else {
23202355
if matches!(tx_mode, TransactionMode::Concurrent) {
2356+
// Restore state before returning error
2357+
if updated {
2358+
conn.set_tx_state(current_state);
2359+
}
23212360
return Err(LimboError::TxError(
23222361
"Concurrent transaction mode is only supported when MVCC is enabled"
23232362
.to_string(),
@@ -2344,27 +2383,20 @@ pub fn op_transaction_inner(
23442383
// start a new one.
23452384
if matches!(current_state, TransactionState::None) {
23462385
pager.end_read_tx();
2347-
conn.set_tx_state(TransactionState::None);
23482386
state.auto_txn_cleanup = TxnCleanup::None;
23492387
}
2350-
assert_eq!(conn.get_tx_state(), current_state);
2388+
// Restore the transaction state since pager ops failed
2389+
conn.set_tx_state(current_state);
23512390
return Err(LimboError::Busy);
23522391
}
23532392
if let IOResult::IO(io) = begin_w_tx_res? {
23542393
// set the transaction state to pending so we don't have to
23552394
// end the read transaction.
2356-
program
2357-
.connection
2358-
.set_tx_state(TransactionState::PendingUpgrade);
2395+
conn.set_tx_state(TransactionState::PendingUpgrade);
23592396
return Ok(InsnFunctionStepResult::IO(io));
23602397
}
23612398
}
23622399
}
2363-
2364-
// 3. Transaction state should be updated before checking for Schema cookie so that the tx is ended properly on error
2365-
if updated {
2366-
conn.set_tx_state(new_transaction_state);
2367-
}
23682400
state.op_transaction_state = OpTransactionState::CheckSchemaCookie;
23692401
continue;
23702402
}

macros/src/atomic_enum.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,23 @@ pub(crate) fn derive_atomic_enum_inner(input: TokenStream) -> TokenStream {
263263
let prev = self.0.swap(Self::to_storage(&val), ::std::sync::atomic::Ordering::SeqCst);
264264
Self::from_storage(prev)
265265
}
266+
267+
#[inline]
268+
/// Compare and exchange: if the current value equals `current`, replace it with `new`.
269+
/// Returns Ok(current) on success, or Err(actual_value) on failure.
270+
pub fn compare_exchange(&self, current: #name, new: #name) -> Result<#name, #name> {
271+
let current_storage = Self::to_storage(&current);
272+
let new_storage = Self::to_storage(&new);
273+
match self.0.compare_exchange(
274+
current_storage,
275+
new_storage,
276+
::std::sync::atomic::Ordering::SeqCst,
277+
::std::sync::atomic::Ordering::SeqCst,
278+
) {
279+
Ok(prev) => Ok(Self::from_storage(prev)),
280+
Err(actual) => Err(Self::from_storage(actual)),
281+
}
282+
}
266283
}
267284

268285
impl From<#name> for #atomic_name {

0 commit comments

Comments
 (0)