Skip to content

Commit 610d8cc

Browse files
authored
Merge 'introduce program execution state in order to run stmt to completion in case of finalize or reset' from Nikita Sivukhin
This PR introduces program execution state in order for statement to be aware of its state - is it terminal (Done, Failed, Interrupted) or not. The particular problem right now is that statements like `INSERT INTO t VALUES (1), (2), (3) RETURNING x` will execute inserts one by one and interleave them with rows generation. This means that if statement consumer will just read one row and then finalize the statement - nothing will be actually committed (because transaction will be aborted). In order to quickly mitigate this issue - program state is introduced which can help to decide what to do in the finalize. Reviewed-by: Jussi Saurio <[email protected]> Closes #4038
2 parents d1aa610 + 1bce053 commit 610d8cc

File tree

4 files changed

+137
-16
lines changed

4 files changed

+137
-16
lines changed

core/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ mod util;
3131
#[cfg(feature = "uuid")]
3232
mod uuid;
3333
mod vdbe;
34+
pub type ProgramExecutionState = vdbe::ProgramExecutionState;
3435
pub mod vector;
3536
mod vtab;
3637

@@ -2650,6 +2651,10 @@ impl Statement {
26502651
self.state.interrupt();
26512652
}
26522653

2654+
pub fn execution_state(&self) -> ProgramExecutionState {
2655+
self.state.execution_state
2656+
}
2657+
26532658
fn _step(&mut self, waker: Option<&Waker>) -> Result<StepResult> {
26542659
if let Some(busy_timeout) = self.busy_timeout.as_ref() {
26552660
if self.pager.io.now() < busy_timeout.timeout {

core/vdbe/mod.rs

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,39 @@ pub enum TxnCleanup {
273273
RollbackTxn,
274274
}
275275

276+
#[derive(Debug, Clone, Copy, PartialEq)]
277+
pub enum ProgramExecutionState {
278+
/// No steps of the program was executed
279+
Init,
280+
/// Program started execution but didn't reach any terminal state
281+
Running,
282+
/// Interrupt requested for the program
283+
Interrupting,
284+
/// Terminal state: program interrupted
285+
Interrupted,
286+
/// Terminal state: program finished successfully
287+
Done,
288+
/// Terminal state: program failed with error
289+
Failed,
290+
}
291+
292+
impl ProgramExecutionState {
293+
pub fn is_running(&self) -> bool {
294+
matches!(
295+
self,
296+
ProgramExecutionState::Interrupting | ProgramExecutionState::Running
297+
)
298+
}
299+
pub fn is_terminal(&self) -> bool {
300+
matches!(
301+
self,
302+
ProgramExecutionState::Interrupted
303+
| ProgramExecutionState::Failed
304+
| ProgramExecutionState::Done
305+
)
306+
}
307+
}
308+
276309
/// The program state describes the environment in which the program executes.
277310
pub struct ProgramState {
278311
pub io_completions: Option<IOCompletions>,
@@ -287,7 +320,7 @@ pub struct ProgramState {
287320
/// Indicate whether an [Insn::Once] instruction at a given program counter position has already been executed, well, once.
288321
once: SmallVec<u32, 4>,
289322
regex_cache: RegexCache,
290-
interrupted: bool,
323+
pub execution_state: ProgramExecutionState,
291324
pub parameters: HashMap<NonZero<usize>, Value>,
292325
commit_state: CommitState,
293326
#[cfg(feature = "json")]
@@ -356,7 +389,7 @@ impl ProgramState {
356389
ended_coroutine: Bitfield::new(),
357390
once: SmallVec::<u32, 4>::new(),
358391
regex_cache: RegexCache::new(),
359-
interrupted: false,
392+
execution_state: ProgramExecutionState::Init,
360393
parameters: HashMap::new(),
361394
commit_state: CommitState::Ready,
362395
#[cfg(feature = "json")]
@@ -409,11 +442,7 @@ impl ProgramState {
409442
}
410443

411444
pub fn interrupt(&mut self) {
412-
self.interrupted = true;
413-
}
414-
415-
pub fn is_interrupted(&self) -> bool {
416-
self.interrupted
445+
self.execution_state = ProgramExecutionState::Interrupting;
417446
}
418447

419448
pub fn bind_at(&mut self, index: NonZero<usize>, value: Value) {
@@ -450,7 +479,7 @@ impl ProgramState {
450479
self.deferred_seeks.iter_mut().for_each(|s| *s = None);
451480
self.ended_coroutine.0 = [0; 4];
452481
self.regex_cache.like.clear();
453-
self.interrupted = false;
482+
self.execution_state = ProgramExecutionState::Init;
454483
self.current_collation = None;
455484
#[cfg(feature = "json")]
456485
self.json_cache.clear();
@@ -650,11 +679,25 @@ impl Program {
650679
query_mode: QueryMode,
651680
waker: Option<&Waker>,
652681
) -> Result<StepResult> {
653-
match query_mode {
682+
state.execution_state = ProgramExecutionState::Running;
683+
let result = match query_mode {
654684
QueryMode::Normal => self.normal_step(state, mv_store, pager, waker),
655685
QueryMode::Explain => self.explain_step(state, mv_store, pager),
656686
QueryMode::ExplainQueryPlan => self.explain_query_plan_step(state, mv_store, pager),
687+
};
688+
match &result {
689+
Ok(StepResult::Done) => {
690+
state.execution_state = ProgramExecutionState::Done;
691+
}
692+
Ok(StepResult::Interrupt) => {
693+
state.execution_state = ProgramExecutionState::Interrupted;
694+
}
695+
Err(_) => {
696+
state.execution_state = ProgramExecutionState::Failed;
697+
}
698+
_ => {}
657699
}
700+
result
658701
}
659702

660703
fn explain_step(
@@ -673,7 +716,7 @@ impl Program {
673716
return Err(LimboError::InternalError("Connection closed".to_string()));
674717
}
675718

676-
if state.is_interrupted() {
719+
if matches!(state.execution_state, ProgramExecutionState::Interrupting) {
677720
return Ok(StepResult::Interrupt);
678721
}
679722

@@ -823,7 +866,7 @@ impl Program {
823866
return Err(LimboError::InternalError("Connection closed".to_string()));
824867
}
825868

826-
if state.is_interrupted() {
869+
if matches!(state.execution_state, ProgramExecutionState::Interrupting) {
827870
return Ok(StepResult::Interrupt);
828871
}
829872

@@ -871,10 +914,11 @@ impl Program {
871914
}
872915
return Err(LimboError::InternalError("Connection closed".to_string()));
873916
}
874-
if state.is_interrupted() {
917+
if matches!(state.execution_state, ProgramExecutionState::Interrupting) {
875918
self.abort(mv_store, &pager, None, state);
876919
return Ok(StepResult::Interrupt);
877920
}
921+
878922
if let Some(io) = &state.io_completions {
879923
if !io.finished() {
880924
io.set_waker(waker);

sqlite3/src/lib.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,32 @@ pub unsafe extern "C" fn sqlite3_prepare_v2(
325325
SQLITE_OK
326326
}
327327

328+
unsafe fn stmt_run_to_completion(stmt: *mut sqlite3_stmt) -> ffi::c_int {
329+
let stmt_ref = &mut *stmt;
330+
while stmt_ref.stmt.execution_state().is_running() {
331+
let result = sqlite3_step(stmt);
332+
if result != SQLITE_DONE && result != SQLITE_ROW {
333+
return result;
334+
}
335+
}
336+
SQLITE_OK
337+
}
338+
328339
#[no_mangle]
329340
pub unsafe extern "C" fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> ffi::c_int {
330341
if stmt.is_null() {
331342
return SQLITE_MISUSE;
332343
}
333344
let stmt_ref = &mut *stmt;
334345

346+
// first, finalize any execution if it was unfinished
347+
// (for example, many drivers can consume just one row and finalize statement after that, while there still can be work to do)
348+
// (this is necessary because queries like INSERT INTO t VALUES (1), (2), (3) RETURNING id return values within a transaction)
349+
let result = stmt_run_to_completion(stmt);
350+
if result != SQLITE_OK {
351+
return result;
352+
}
353+
335354
if !stmt_ref.db.is_null() {
336355
let db = &mut *stmt_ref.db;
337356
let mut db_inner = db.inner.lock().unwrap();
@@ -669,6 +688,13 @@ fn split_sql_statements(sql: &str) -> Vec<&str> {
669688
#[no_mangle]
670689
pub unsafe extern "C" fn sqlite3_reset(stmt: *mut sqlite3_stmt) -> ffi::c_int {
671690
let stmt = &mut *stmt;
691+
// first, finalize any execution if it was unfinished
692+
// (for example, many drivers can consume just one row and finalize statement after that, while there still can be work to do)
693+
// (this is necessary because queries like INSERT INTO t VALUES (1), (2), (3) RETURNING id return values within a transaction)
694+
let result = stmt_run_to_completion(stmt);
695+
if result != SQLITE_OK {
696+
return result;
697+
}
672698
stmt.stmt.reset();
673699
stmt.clear_text_cache();
674700
SQLITE_OK
@@ -1420,10 +1446,10 @@ unsafe extern "C" fn sqlite_get_table_cb(
14201446
for i in 0..n_column {
14211447
let value = *argv.add(i as usize);
14221448
let value_cstring = if !value.is_null() {
1423-
let len = libc::strlen(value);
1424-
let mut buf = Vec::with_capacity(len + 1);
1425-
libc::strncpy(buf.as_mut_ptr() as *mut ffi::c_char, value, len);
1426-
buf.set_len(len + 1);
1449+
let value_cstr = CStr::from_ptr(value).to_bytes();
1450+
let len = value_cstr.len();
1451+
let mut buf = vec![0u8; len + 1];
1452+
buf[0..len].copy_from_slice(value_cstr);
14271453
CString::from_vec_with_nul(buf).unwrap()
14281454
} else {
14291455
CString::new("NULL").unwrap()

sqlite3/tests/sqlite3_tests.c

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void test_sqlite3_column_type();
2020
void test_sqlite3_column_decltype();
2121
void test_sqlite3_next_stmt();
2222
void test_sqlite3_table_column_metadata();
23+
void test_sqlite3_insert_returning();
2324

2425
int allocated = 0;
2526

@@ -39,6 +40,7 @@ int main(void)
3940
test_sqlite3_column_decltype();
4041
test_sqlite3_next_stmt();
4142
test_sqlite3_table_column_metadata();
43+
test_sqlite3_insert_returning();
4244
return 0;
4345
}
4446

@@ -746,3 +748,47 @@ void test_sqlite3_table_column_metadata()
746748
printf("sqlite3_table_column_metadata test passed\n");
747749
sqlite3_close(db);
748750
}
751+
752+
void test_sqlite3_insert_returning()
753+
{
754+
sqlite3 *db;
755+
sqlite3_stmt *stmt;
756+
char *err_msg = NULL;
757+
int rc;
758+
759+
rc = sqlite3_open(":memory:", &db);
760+
assert(rc == SQLITE_OK);
761+
762+
rc = sqlite3_exec(db,
763+
"CREATE TABLE t(x)",
764+
NULL, NULL, &err_msg);
765+
assert(rc == SQLITE_OK);
766+
767+
if (err_msg)
768+
{
769+
sqlite3_free(err_msg);
770+
err_msg = NULL;
771+
}
772+
rc = sqlite3_prepare_v2(db, "INSERT INTO t (x) VALUES (1), (2), (3) RETURNING x;", -1, &stmt, NULL);
773+
assert(rc == SQLITE_OK);
774+
775+
rc = sqlite3_step(stmt);
776+
assert(rc == SQLITE_ROW);
777+
778+
rc = sqlite3_finalize(stmt);
779+
assert(rc == SQLITE_OK);
780+
781+
rc = sqlite3_prepare_v2(db, "SELECT COUNT(*) FROM t;", -1, &stmt, NULL);
782+
assert(rc == SQLITE_OK);
783+
784+
rc = sqlite3_step(stmt);
785+
assert(rc == SQLITE_ROW);
786+
787+
sqlite_int64 fetched = sqlite3_column_int64(stmt, 0);
788+
assert(fetched == 3);
789+
790+
sqlite3_finalize(stmt);
791+
sqlite3_close(db);
792+
793+
printf("test_sqlite3_insert_retuning test passed\n");
794+
}

0 commit comments

Comments
 (0)