Skip to content

Commit e39e60e

Browse files
committed
introduce program execution state in order to run stmt to completion in case of finalize or reset
1 parent cb3ae4d commit e39e60e

File tree

4 files changed

+115
-17
lines changed

4 files changed

+115
-17
lines changed

core/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ mod util;
3131
#[cfg(feature = "uuid")]
3232
mod uuid;
3333
mod vdbe;
34+
pub type ProgramExecutionState = vdbe::ProgramExecutionState;
3435
pub mod vector;
3536
mod vtab;
3637

@@ -2650,6 +2651,10 @@ impl Statement {
26502651
self.state.interrupt();
26512652
}
26522653

2654+
pub fn execution_state(&self) -> ProgramExecutionState {
2655+
self.state.execution_state
2656+
}
2657+
26532658
fn _step(&mut self, waker: Option<&Waker>) -> Result<StepResult> {
26542659
if let Some(busy_timeout) = self.busy_timeout.as_ref() {
26552660
if self.pager.io.now() < busy_timeout.timeout {

core/vdbe/mod.rs

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,16 @@ pub enum TxnCleanup {
273273
RollbackTxn,
274274
}
275275

276+
#[derive(Debug, Clone, Copy, PartialEq)]
277+
pub enum ProgramExecutionState {
278+
Init,
279+
Run,
280+
Interrupting,
281+
Interrupted,
282+
Done,
283+
Failed,
284+
}
285+
276286
/// The program state describes the environment in which the program executes.
277287
pub struct ProgramState {
278288
pub io_completions: Option<IOCompletions>,
@@ -287,7 +297,7 @@ pub struct ProgramState {
287297
/// Indicate whether an [Insn::Once] instruction at a given program counter position has already been executed, well, once.
288298
once: SmallVec<u32, 4>,
289299
regex_cache: RegexCache,
290-
interrupted: bool,
300+
pub execution_state: ProgramExecutionState,
291301
pub parameters: HashMap<NonZero<usize>, Value>,
292302
commit_state: CommitState,
293303
#[cfg(feature = "json")]
@@ -356,7 +366,7 @@ impl ProgramState {
356366
ended_coroutine: Bitfield::new(),
357367
once: SmallVec::<u32, 4>::new(),
358368
regex_cache: RegexCache::new(),
359-
interrupted: false,
369+
execution_state: ProgramExecutionState::Init,
360370
parameters: HashMap::new(),
361371
commit_state: CommitState::Ready,
362372
#[cfg(feature = "json")]
@@ -409,11 +419,7 @@ impl ProgramState {
409419
}
410420

411421
pub fn interrupt(&mut self) {
412-
self.interrupted = true;
413-
}
414-
415-
pub fn is_interrupted(&self) -> bool {
416-
self.interrupted
422+
self.execution_state = ProgramExecutionState::Interrupting;
417423
}
418424

419425
pub fn bind_at(&mut self, index: NonZero<usize>, value: Value) {
@@ -450,7 +456,7 @@ impl ProgramState {
450456
self.deferred_seeks.iter_mut().for_each(|s| *s = None);
451457
self.ended_coroutine.0 = [0; 4];
452458
self.regex_cache.like.clear();
453-
self.interrupted = false;
459+
self.execution_state = ProgramExecutionState::Init;
454460
self.current_collation = None;
455461
#[cfg(feature = "json")]
456462
self.json_cache.clear();
@@ -650,11 +656,25 @@ impl Program {
650656
query_mode: QueryMode,
651657
waker: Option<&Waker>,
652658
) -> Result<StepResult> {
653-
match query_mode {
659+
state.execution_state = ProgramExecutionState::Run;
660+
let result = match query_mode {
654661
QueryMode::Normal => self.normal_step(state, mv_store, pager, waker),
655662
QueryMode::Explain => self.explain_step(state, mv_store, pager),
656663
QueryMode::ExplainQueryPlan => self.explain_query_plan_step(state, mv_store, pager),
664+
};
665+
match &result {
666+
Ok(StepResult::Done) => {
667+
state.execution_state = ProgramExecutionState::Done;
668+
}
669+
Ok(StepResult::Interrupt) => {
670+
state.execution_state = ProgramExecutionState::Interrupted;
671+
}
672+
Err(_) => {
673+
state.execution_state = ProgramExecutionState::Failed;
674+
}
675+
_ => {}
657676
}
677+
result
658678
}
659679

660680
fn explain_step(
@@ -673,7 +693,7 @@ impl Program {
673693
return Err(LimboError::InternalError("Connection closed".to_string()));
674694
}
675695

676-
if state.is_interrupted() {
696+
if matches!(state.execution_state, ProgramExecutionState::Interrupting) {
677697
return Ok(StepResult::Interrupt);
678698
}
679699

@@ -823,7 +843,7 @@ impl Program {
823843
return Err(LimboError::InternalError("Connection closed".to_string()));
824844
}
825845

826-
if state.is_interrupted() {
846+
if matches!(state.execution_state, ProgramExecutionState::Interrupting) {
827847
return Ok(StepResult::Interrupt);
828848
}
829849

@@ -871,10 +891,11 @@ impl Program {
871891
}
872892
return Err(LimboError::InternalError("Connection closed".to_string()));
873893
}
874-
if state.is_interrupted() {
894+
if matches!(state.execution_state, ProgramExecutionState::Interrupting) {
875895
self.abort(mv_store, &pager, None, state);
876896
return Ok(StepResult::Interrupt);
877897
}
898+
878899
if let Some(io) = &state.io_completions {
879900
if !io.finished() {
880901
io.set_waker(waker);

sqlite3/src/lib.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use std::ffi::{self, CStr, CString};
55
use std::num::{NonZero, NonZeroUsize};
66
use tracing::trace;
7-
use turso_core::{CheckpointMode, LimboError, Value};
7+
use turso_core::{CheckpointMode, LimboError, ProgramExecutionState, Value};
88

99
use std::sync::{Arc, Mutex};
1010

@@ -325,13 +325,32 @@ pub unsafe extern "C" fn sqlite3_prepare_v2(
325325
SQLITE_OK
326326
}
327327

328+
unsafe fn stmt_run_to_completion(stmt: *mut sqlite3_stmt) -> ffi::c_int {
329+
let stmt_ref = &mut *stmt;
330+
while stmt_ref.stmt.execution_state() == ProgramExecutionState::Run {
331+
let result = sqlite3_step(stmt);
332+
if result != SQLITE_DONE && result != SQLITE_ROW {
333+
return result;
334+
}
335+
}
336+
SQLITE_OK
337+
}
338+
328339
#[no_mangle]
329340
pub unsafe extern "C" fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> ffi::c_int {
330341
if stmt.is_null() {
331342
return SQLITE_MISUSE;
332343
}
333344
let stmt_ref = &mut *stmt;
334345

346+
// first, finalize any execution if it was unfinished
347+
// (for example, many drivers can consume just one row and finalize statement after that, while there still can be work to do)
348+
// (this is necessary because queries like INSERT INTO t VALUES (1), (2), (3) RETURNING id return values within a transaction)
349+
let result = stmt_run_to_completion(stmt);
350+
if result != SQLITE_OK {
351+
return result;
352+
}
353+
335354
if !stmt_ref.db.is_null() {
336355
let db = &mut *stmt_ref.db;
337356
let mut db_inner = db.inner.lock().unwrap();
@@ -669,6 +688,13 @@ fn split_sql_statements(sql: &str) -> Vec<&str> {
669688
#[no_mangle]
670689
pub unsafe extern "C" fn sqlite3_reset(stmt: *mut sqlite3_stmt) -> ffi::c_int {
671690
let stmt = &mut *stmt;
691+
// first, finalize any execution if it was unfinished
692+
// (for example, many drivers can consume just one row and finalize statement after that, while there still can be work to do)
693+
// (this is necessary because queries like INSERT INTO t VALUES (1), (2), (3) RETURNING id return values within a transaction)
694+
let result = stmt_run_to_completion(stmt);
695+
if result != SQLITE_OK {
696+
return result;
697+
}
672698
stmt.stmt.reset();
673699
stmt.clear_text_cache();
674700
SQLITE_OK
@@ -1420,10 +1446,10 @@ unsafe extern "C" fn sqlite_get_table_cb(
14201446
for i in 0..n_column {
14211447
let value = *argv.add(i as usize);
14221448
let value_cstring = if !value.is_null() {
1423-
let len = libc::strlen(value);
1424-
let mut buf = Vec::with_capacity(len + 1);
1425-
libc::strncpy(buf.as_mut_ptr() as *mut ffi::c_char, value, len);
1426-
buf.set_len(len + 1);
1449+
let value_cstr = CStr::from_ptr(value).to_bytes();
1450+
let len = value_cstr.len();
1451+
let mut buf = vec![0u8; len + 1];
1452+
buf[0..len].copy_from_slice(value_cstr);
14271453
CString::from_vec_with_nul(buf).unwrap()
14281454
} else {
14291455
CString::new("NULL").unwrap()

sqlite3/tests/sqlite3_tests.c

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void test_sqlite3_column_type();
2020
void test_sqlite3_column_decltype();
2121
void test_sqlite3_next_stmt();
2222
void test_sqlite3_table_column_metadata();
23+
void test_sqlite3_insert_returning();
2324

2425
int allocated = 0;
2526

@@ -39,6 +40,7 @@ int main(void)
3940
test_sqlite3_column_decltype();
4041
test_sqlite3_next_stmt();
4142
test_sqlite3_table_column_metadata();
43+
test_sqlite3_insert_returning();
4244
return 0;
4345
}
4446

@@ -746,3 +748,47 @@ void test_sqlite3_table_column_metadata()
746748
printf("sqlite3_table_column_metadata test passed\n");
747749
sqlite3_close(db);
748750
}
751+
752+
void test_sqlite3_insert_returning()
753+
{
754+
sqlite3 *db;
755+
sqlite3_stmt *stmt;
756+
char *err_msg = NULL;
757+
int rc;
758+
759+
rc = sqlite3_open(":memory:", &db);
760+
assert(rc == SQLITE_OK);
761+
762+
rc = sqlite3_exec(db,
763+
"CREATE TABLE t(x)",
764+
NULL, NULL, &err_msg);
765+
assert(rc == SQLITE_OK);
766+
767+
if (err_msg)
768+
{
769+
sqlite3_free(err_msg);
770+
err_msg = NULL;
771+
}
772+
rc = sqlite3_prepare_v2(db, "INSERT INTO t (x) VALUES (1), (2), (3) RETURNING x;", -1, &stmt, NULL);
773+
assert(rc == SQLITE_OK);
774+
775+
rc = sqlite3_step(stmt);
776+
assert(rc == SQLITE_ROW);
777+
778+
rc = sqlite3_finalize(stmt);
779+
assert(rc == SQLITE_OK);
780+
781+
rc = sqlite3_prepare_v2(db, "SELECT COUNT(*) FROM t;", -1, &stmt, NULL);
782+
assert(rc == SQLITE_OK);
783+
784+
rc = sqlite3_step(stmt);
785+
assert(rc == SQLITE_ROW);
786+
787+
sqlite_int64 fetched = sqlite3_column_int64(stmt, 0);
788+
assert(fetched == 3);
789+
790+
sqlite3_finalize(stmt);
791+
sqlite3_close(db);
792+
793+
printf("test_sqlite3_insert_retuning test passed\n");
794+
}

0 commit comments

Comments
 (0)