diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e059236c1c..da7fe814fe 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -58,7 +58,7 @@ jobs: - uses: actions/checkout@v3 - name: Clippy run: | - cargo clippy --workspace --all-features --all-targets -- --deny=warnings + cargo clippy --workspace --all-features --lib --bins -- --deny=warnings -D clippy::unwrap_used simulator: runs-on: blacksmith-4vcpu-ubuntu-2404 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1aaa08f320..0f411547f4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -128,6 +128,14 @@ echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid cargo bench --bench benchmark -- --profile-time=5 ``` +## Coding Style + +### Error handling + +* Turso is a library, which means we prefer error handling over crashing. +* However, we care about correctness even more than we care about not crashing, which is why we use assertions liberally. +* We don't use `unwrap()` in Turso code, except for tests. Instead, we implement error handling or -- when that is not possible -- use `expect()` to clearly communicate that this is an invariant (assertion). + ## Debugging bugs ### Query execution debugging diff --git a/cli/build.rs b/cli/build.rs index a461c7652f..96d93bf094 100644 --- a/cli/build.rs +++ b/cli/build.rs @@ -13,9 +13,9 @@ fn main() { println!("cargo::rerun-if-changed=build.rs"); println!("cargo::rerun-if-changed=manuals"); - let out_dir = env::var_os("OUT_DIR").unwrap(); - let syntax = - SyntaxDefinition::load_from_str(include_str!("./SQL.sublime-syntax"), false, None).unwrap(); + let out_dir = env::var_os("OUT_DIR").expect("OUT_DIR not set by cargo"); + let syntax = SyntaxDefinition::load_from_str(include_str!("./SQL.sublime-syntax"), false, None) + .expect("failed to load SQL syntax definition"); let mut ps = SyntaxSet::new().into_builder(); ps.add(syntax); let ps = ps.build(); @@ -23,5 +23,5 @@ fn main() { &ps, Path::new(&out_dir).join("SQL_syntax_set_dump.packdump"), ) - .unwrap(); + .expect("failed to dump syntax set"); } diff --git a/core/build.rs b/core/build.rs index 270fae9255..5d6a8828de 100644 --- a/core/build.rs +++ b/core/build.rs @@ -8,7 +8,7 @@ fn main() { println!("cargo::rerun-if-changed=build.rs"); } - let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); + let out_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR not set by cargo")); let built_file = out_dir.join("built.rs"); built::write_built_file().expect("Failed to acquire build-time information"); @@ -19,7 +19,8 @@ fn main() { &built_file, format!( "{}\npub const BUILT_TIME_SQLITE: &str = \"{}\";\n", - fs::read_to_string(&built_file).unwrap(), + fs::read_to_string(&built_file) + .expect("built.rs should exist after built::write_built_file()"), sqlite_date ), ) diff --git a/core/ext/dynamic.rs b/core/ext/dynamic.rs index c31c222290..0e6538647c 100644 --- a/core/ext/dynamic.rs +++ b/core/ext/dynamic.rs @@ -4,9 +4,10 @@ use crate::{ }; #[cfg(not(target_family = "wasm"))] use libloading::{Library, Symbol}; +use parking_lot::Mutex; use std::{ ffi::{c_char, CString}, - sync::{Arc, Mutex, OnceLock}, + sync::{Arc, OnceLock}, }; use turso_ext::{ExtensionApi, ExtensionApiRef, ExtensionEntryPoint, ResultCode, VfsImpl}; @@ -52,12 +53,7 @@ impl Connection { let result_code = unsafe { entry(api_ptr) }; if result_code.is_ok() { let extensions = get_extension_libraries(); - extensions - .lock() - .map_err(|_| { - LimboError::ExtensionError("Error locking extension libraries".to_string()) - })?
- .push((Arc::new(lib), api_ref)); + extensions.lock().push((Arc::new(lib), api_ref)); if self.is_db_initialized() { self.parse_schema_rows()?; } @@ -156,10 +152,7 @@ fn register_static_vfs_modules(_api: &mut ExtensionApi) { } pub fn add_vfs_module(name: String, vfs: Arc) { - let mut modules = VFS_MODULES - .get_or_init(|| Mutex::new(Vec::new())) - .lock() - .unwrap(); + let mut modules = VFS_MODULES.get_or_init(|| Mutex::new(Vec::new())).lock(); if !modules.iter().any(|v| v.0 == name) { modules.push((name, vfs)); } @@ -169,7 +162,6 @@ pub fn list_vfs_modules() -> Vec<String> { VFS_MODULES .get_or_init(|| Mutex::new(Vec::new())) .lock() - .unwrap() .iter() .map(|v| v.0.clone()) .collect() @@ -179,6 +171,5 @@ pub fn get_vfs_modules() -> Vec { VFS_MODULES .get_or_init(|| Mutex::new(Vec::new())) .lock() - .unwrap() .clone() } diff --git a/core/ext/vtab_xconnect.rs b/core/ext/vtab_xconnect.rs index 7cd1321cf2..d3ea08ca1a 100644 --- a/core/ext/vtab_xconnect.rs +++ b/core/ext/vtab_xconnect.rs @@ -42,7 +42,7 @@ pub unsafe extern "C" fn execute( let args_slice = &mut std::slice::from_raw_parts_mut(args, arg_count as usize); for (i, val) in args_slice.iter_mut().enumerate() { stmt.bind_at( - NonZeroUsize::new(i + 1).unwrap(), + NonZeroUsize::new(i + 1).expect("i + 1 is always non-zero"), Value::from_ffi(std::mem::take(val)).unwrap_or(Value::Null), ); } diff --git a/core/functions/datetime.rs b/core/functions/datetime.rs index 703f5ad23c..506a8d3715 100644 --- a/core/functions/datetime.rs +++ b/core/functions/datetime.rs @@ -49,7 +49,9 @@ where return Value::Null; } - let value = values.next().unwrap(); + let Some(value) = values.next() else { + return Value::Null; + }; let value = value.as_value_ref(); let format_str = if matches!( value, @@ -80,11 +82,15 @@ where { let values = values.into_iter(); if values.len() == 0 { - let now = parse_naive_date_time(Value::build_text("now")).unwrap(); + let Some(now) = parse_naive_date_time(Value::build_text("now")) else { + return Value::Null; + }; return format_dt(now, output_type, false); } let mut values = values.peekable(); - let first = values.peek().unwrap(); + let Some(first) = values.peek() else { + return Value::Null; + }; if let Some(mut dt) = parse_naive_date_time(first) { // if successful, treat subsequent entries as modifiers modify_dt(&mut dt, values.skip(1), output_type) @@ -231,18 +237,19 @@ fn apply_modifier(dt: &mut NaiveDateTime, modifier: &str, n_floor: &mut i64) -> } Modifier::StartOfMonth => { *dt = NaiveDate::from_ymd_opt(dt.year(), dt.month(), 1) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap(); + .and_then(|d| d.and_hms_opt(0, 0, 0)) + .ok_or_else(|| InvalidModifier("failed to construct start of month".to_string()))?; } Modifier::StartOfYear => { *dt = NaiveDate::from_ymd_opt(dt.year(), 1, 1) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap(); + .and_then(|d| d.and_hms_opt(0, 0, 0)) + .ok_or_else(|| InvalidModifier("failed to construct start of year".to_string()))?; } Modifier::StartOfDay => { - *dt = dt.date().and_hms_opt(0, 0, 0).unwrap(); + *dt = dt + .date() + .and_hms_opt(0, 0, 0) + .ok_or_else(|| InvalidModifier("failed to construct start of day".to_string()))?; } Modifier::Weekday(day) => { let current_day = dt.weekday().num_days_from_sunday(); @@ -259,11 +266,18 @@ fn apply_modifier(dt: &mut NaiveDateTime, modifier: &str, n_floor: &mut i64) -> } Modifier::Utc => { // TODO: handle datetime('now', 'utc') no-op - let local_dt = chrono::Local.from_local_datetime(dt).unwrap(); + let local_dt = chrono::Local + .from_local_datetime(dt) + .single() +
.ok_or_else(|| { + InvalidModifier("ambiguous local datetime during DST transition".to_string()) + })?; *dt = local_dt.with_timezone(&Utc).naive_utc(); } Modifier::Subsec => { - *dt = dt.with_nanosecond(dt.nanosecond()).unwrap(); + *dt = dt + .with_nanosecond(dt.nanosecond()) + .ok_or_else(|| InvalidModifier("failed to set nanoseconds".to_string()))?; return Ok(true); } } @@ -482,7 +496,7 @@ fn get_date_time_from_time_value_string(value: &str) -> Option<NaiveDateTime> { // First, try to parse as date-only format if let Ok(date) = NaiveDate::parse_from_str(value, date_only_format) { - return Some(date.and_time(NaiveTime::from_hms_opt(0, 0, 0).unwrap())); + return NaiveTime::from_hms_opt(0, 0, 0).map(|time| date.and_time(time)); } for format in &datetime_formats { @@ -598,10 +612,9 @@ fn is_leap_second(dt: &NaiveDateTime) -> bool { fn get_max_datetime_exclusive() -> NaiveDateTime { // The maximum date in SQLite is 9999-12-31 - NaiveDateTime::new( - NaiveDate::from_ymd_opt(10000, 1, 1).unwrap(), - NaiveTime::from_hms_milli_opt(00, 00, 00, 000).unwrap(), - ) + let date = NaiveDate::from_ymd_opt(10000, 1, 1).expect("10000-01-01 is valid"); + let time = NaiveTime::from_hms_milli_opt(00, 00, 00, 000).expect("00:00:00.000 is valid"); + NaiveDateTime::new(date, time) } /// Modifier doc https://www.sqlite.org/lang_datefunc.html#modifiers diff --git a/core/io/clock.rs b/core/io/clock.rs index 06edc65e39..e135aa59da 100644 --- a/core/io/clock.rs +++ b/core/io/clock.rs @@ -72,7 +72,8 @@ impl std::ops::Add<Duration> for Instant { type Output = Instant; fn add(self, rhs: Duration) -> Self::Output { - self.checked_add_duration(&rhs).unwrap() + self.checked_add_duration(&rhs) + .expect("duration addition overflow") } } @@ -80,7 +81,8 @@ impl std::ops::Sub<Duration> for Instant { type Output = Instant; fn sub(self, rhs: Duration) -> Self::Output { - self.checked_sub_duration(&rhs).unwrap() + self.checked_sub_duration(&rhs) + .expect("duration subtraction underflow") } } diff --git a/core/io/completions.rs b/core/io/completions.rs index a1f038a545..1fcf851fb6 100644 --- a/core/io/completions.rs +++ b/core/io/completions.rs @@ -272,7 +272,9 @@ impl Completion { } pub(super) fn get_inner(&self) -> &Arc { - self.inner.as_ref().unwrap() + self.inner + .as_ref() + .expect("completion inner should be initialized") } pub fn needs_link(&self) -> bool { diff --git a/core/io/memory.rs b/core/io/memory.rs index d7f57fcd40..885ab2a8da 100644 --- a/core/io/memory.rs +++ b/core/io/memory.rs @@ -59,7 +59,12 @@ impl IO for MemoryIO { }), ); } - Ok(files.get(path).unwrap().clone()) + Ok(files + .get(path) + .ok_or(crate::LimboError::InternalError( + "file should exist after insert".to_string(), + ))?
+ .clone()) } fn remove_file(&self, path: &str) -> Result<()> { let mut files = self.files.lock(); diff --git a/core/mvcc/cursor.rs b/core/mvcc/cursor.rs index 2819bd5c31..98df1a4e20 100644 --- a/core/mvcc/cursor.rs +++ b/core/mvcc/cursor.rs @@ -5,7 +5,7 @@ use crate::mvcc::database::{MVTableId, MvStore, Row, RowID, RowKey, RowVersionSt use crate::storage::btree::{BTreeCursor, BTreeKey, CursorTrait}; use crate::translate::plan::IterationDirection; use crate::types::{IOResult, ImmutableRecord, RecordCursor, SeekKey, SeekOp, SeekResult}; -use crate::{return_if_io, Result}; +use crate::{return_if_io, LimboError, Result}; use crate::{Pager, Value}; use std::any::Any; use std::cell::{Ref, RefCell}; @@ -110,7 +110,9 @@ impl MvccLazyCursor { }; { let mut record = self.get_immutable_record_or_create(); - let record = record.as_mut().unwrap(); + let record = record.as_mut().ok_or(LimboError::InternalError( + "immutable record not initialized".to_string(), + ))?; record.invalidate(); record.start_serialization(&row.data); } @@ -119,7 +121,10 @@ impl MvccLazyCursor { Ref::filter_map(self.reusable_immutable_record.borrow(), |opt| { opt.as_ref() }) - .unwrap(); + .ok() + .ok_or(LimboError::InternalError( + "immutable record not initialized".to_string(), + ))?; Ok(IOResult::Done(Some(record_ref))) } } @@ -587,10 +592,21 @@ impl CursorTrait for MvccLazyCursor { panic!("BTreeKey::maybe_rowid() should return Some(rowid) for table rowid keys"); }; let row_id = RowID::new(self.table_id, RowKey::Int(rowid)); - let record_buf = key.get_record().unwrap().get_payload().to_vec(); + let record_buf = key + .get_record() + .ok_or(LimboError::InternalError( + "BTreeKey should have a record".to_string(), + ))? + .get_payload() + .to_vec(); let num_columns = match key { BTreeKey::IndexKey(record) => record.column_count(), - BTreeKey::TableRowId((_, record)) => record.as_ref().unwrap().column_count(), + BTreeKey::TableRowId((_, record)) => record + .as_ref() + .ok_or(LimboError::InternalError( + "TableRowId should have a record".to_string(), + ))? 
+ .column_count(), }; let row = crate::mvcc::database::Row::new(row_id, record_buf, num_columns); @@ -770,7 +786,7 @@ impl CursorTrait for MvccLazyCursor { fn invalidate_record(&mut self) { self.get_immutable_record_or_create() .as_mut() - .unwrap() + .expect("immutable record should be initialized") .invalidate(); self.record_cursor.borrow_mut().invalidate(); } diff --git a/core/mvcc/database/checkpoint_state_machine.rs b/core/mvcc/database/checkpoint_state_machine.rs index 51774463c1..dafb488066 100644 --- a/core/mvcc/database/checkpoint_state_machine.rs +++ b/core/mvcc/database/checkpoint_state_machine.rs @@ -9,8 +9,8 @@ use crate::storage::pager::CreateBTreeFlags; use crate::storage::wal::{CheckpointMode, TursoRwLock}; use crate::types::{IOCompletions, IOResult, ImmutableRecord, RecordCursor}; use crate::{ - CheckpointResult, Completion, Connection, IOExt, Pager, Result, TransactionState, Value, - ValueRef, + CheckpointResult, Completion, Connection, IOExt, LimboError, Pager, Result, TransactionState, + Value, ValueRef, }; use parking_lot::RwLock; use std::collections::{HashMap, HashSet}; @@ -194,9 +194,12 @@ impl CheckpointStateMachine { if version.row.id.table_id == SQLITE_SCHEMA_MVCC_TABLE_ID { let row_data = ImmutableRecord::from_bin_record(version.row.data.clone()); let mut record_cursor = RecordCursor::new(); - record_cursor.parse_full_header(&row_data).unwrap(); - if let ValueRef::Integer(root_page) = - record_cursor.get_value(&row_data, 3).unwrap() + record_cursor + .parse_full_header(&row_data) + .expect("failed to parse record header"); + if let ValueRef::Integer(root_page) = record_cursor + .get_value(&row_data, 3) + .expect("failed to get column 3 from sqlite_schema") { if is_delete { let table_id = self @@ -207,7 +210,7 @@ impl CheckpointStateMachine { entry.value().is_some_and(|r| r == root_page as u64) }) .map(|entry| *entry.key()) - .unwrap(); // This assumes a valid mapping exists. + .expect("table_id to rootpage mapping should exist"); self.destroyed_tables.insert(table_id); // We might need to create or destroy a B-tree in the pager during checkpoint if a row in root page 1 is deleted or created. 
@@ -351,8 +354,11 @@ impl CheckpointStateMachine { } let (num_columns, table_id, special_write) = { - let (row_version, special_write) = - self.get_current_row_version(write_set_index).unwrap(); + let (row_version, special_write) = self + .get_current_row_version(write_set_index) + .ok_or(LimboError::InternalError( + "row version not found in write set".to_string(), + ))?; ( row_version.row.column_count, row_version.row.id.table_id, @@ -421,13 +427,15 @@ impl CheckpointStateMachine { .unwrap_or_else(|| { panic!( "Table ID does not have a root page: {table_id}, row_version: {:?}", - self.get_current_row_version(write_set_index).unwrap() + self.get_current_row_version(write_set_index) + .expect("row version should exist") ) }); root_page.value().unwrap_or_else(|| { panic!( "Table ID does not have a root page: {table_id}, row_version: {:?}", - self.get_current_row_version(write_set_index).unwrap() + self.get_current_row_version(write_set_index) + .expect("row version should exist") ) }) }; @@ -445,11 +453,16 @@ impl CheckpointStateMachine { .value() .expect("Table ID does not have a root page") }; - let (row_version, _) = - self.get_current_row_version_mut(write_set_index).unwrap(); + let (row_version, _) = self + .get_current_row_version_mut(write_set_index) + .ok_or(LimboError::InternalError( + "row version not found in write set".to_string(), + ))?; let record = ImmutableRecord::from_bin_record(row_version.row.data.clone()); let mut record_cursor = RecordCursor::new(); - record_cursor.parse_full_header(&record).unwrap(); + record_cursor + .parse_full_header(&record) + .map_err(|e| LimboError::InternalError(e.to_string()))?; let values = record_cursor.get_values(&record); let mut values = values .into_iter() @@ -471,7 +484,9 @@ impl CheckpointStateMachine { cursor }; - let (row_version, _) = self.get_current_row_version(write_set_index).unwrap(); + let (row_version, _) = self.get_current_row_version(write_set_index).ok_or( + LimboError::InternalError("row version not found in write set".to_string()), + )?; // Check if this is an insert or delete if row_version.end.is_some() { @@ -495,7 +510,12 @@ impl CheckpointStateMachine { CheckpointState::WriteRowStateMachine { write_set_index } => { let write_set_index = *write_set_index; - let write_row_state_machine = self.write_row_state_machine.as_mut().unwrap(); + let write_row_state_machine = + self.write_row_state_machine + .as_mut() + .ok_or(LimboError::InternalError( + "write_row_state_machine not initialized".to_string(), + ))?; match write_row_state_machine.step(&())? { IOResult::IO(io) => Ok(TransitionResult::Io(io)), @@ -511,7 +531,12 @@ impl CheckpointStateMachine { CheckpointState::DeleteRowStateMachine { write_set_index } => { let write_set_index = *write_set_index; - let delete_row_state_machine = self.delete_row_state_machine.as_mut().unwrap(); + let delete_row_state_machine = + self.delete_row_state_machine + .as_mut() + .ok_or(LimboError::InternalError( + "delete_row_state_machine not initialized".to_string(), + ))?; match delete_row_state_machine.step(&())? 
{ IOResult::IO(io) => Ok(TransitionResult::Io(io)), @@ -536,17 +561,13 @@ impl CheckpointStateMachine { if self.update_transaction_state { self.connection.set_tx_state(TransactionState::None); } - let header = self - .pager - .io - .block(|| { - self.pager.with_header_mut(|header| { - header.schema_cookie = - self.connection.db.schema.lock().schema_version.into(); - *header - }) + let header = self.pager.io.block(|| { + self.pager.with_header_mut(|header| { + header.schema_cookie = + self.connection.db.schema.lock().schema_version.into(); + *header }) - .unwrap(); + })?; self.mvstore.global_header.write().replace(header); Ok(TransitionResult::Continue) } @@ -598,7 +619,11 @@ impl CheckpointStateMachine { self.checkpoint_lock.unlock(); self.finalize(&())?; Ok(TransitionResult::Done( - self.checkpoint_result.take().unwrap(), + self.checkpoint_result + .take() + .ok_or(LimboError::InternalError( + "checkpoint_result not set".to_string(), + ))?, )) } } diff --git a/core/mvcc/database/mod.rs b/core/mvcc/database/mod.rs index 9d5451572c..37d1359405 100644 --- a/core/mvcc/database/mod.rs +++ b/core/mvcc/database/mod.rs @@ -775,7 +775,10 @@ impl StateTransition for CommitStateMachine { let schema = connection.schema.read().clone(); connection.db.update_schema_if_newer(schema); } - let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); + let tx = mvcc_store + .txs + .get(&self.tx_id) + .ok_or(LimboError::NoSuchTransactionID(self.tx_id.to_string()))?; let tx_unlocked = tx.value(); self.header.write().replace(*tx_unlocked.header.read()); tracing::trace!("end_commit_logical_log(tx_id={})", self.tx_id); @@ -784,7 +787,10 @@ impl StateTransition for CommitStateMachine { return Ok(TransitionResult::Continue); } CommitState::CommitEnd { end_ts } => { - let tx = mvcc_store.txs.get(&self.tx_id).unwrap(); + let tx = mvcc_store + .txs + .get(&self.tx_id) + .ok_or(LimboError::NoSuchTransactionID(self.tx_id.to_string()))?; let tx_unlocked = tx.value(); tx_unlocked .state @@ -1329,7 +1335,10 @@ impl MvStore { /// and `None` otherwise. 
pub fn read(&self, tx_id: TxID, id: RowID) -> Result<Option<Row>> { tracing::trace!("read(tx_id={}, id={:?})", tx_id, id); - let tx = self.txs.get(&tx_id).unwrap(); + let tx = self + .txs + .get(&tx_id) + .ok_or(LimboError::NoSuchTransactionID(tx_id.to_string()))?; let tx = tx.value(); assert_eq!(tx.state, TransactionState::Active); if let Some(row_versions) = self.rows.get(&id) { @@ -1441,7 +1450,10 @@ direction ); - let tx = self.txs.get(&tx_id).unwrap(); + let tx = self + .txs + .get(&tx_id) + .expect("transaction should exist in txs map"); let tx = tx.value(); if direction == IterationDirection::Forwards { let min_bound = RowID { @@ -1506,18 +1518,19 @@ row_id, ); - let tx = self.txs.get(&tx_id).unwrap(); + let tx = self + .txs + .get(&tx_id) + .expect("transaction should exist in txs map"); let tx = tx.value(); - let versions = self.rows.get(&RowID { + let Some(versions) = self.rows.get(&RowID { table_id, row_id: row_id.clone(), - }); - if versions.is_none() { + }) else { return RowVersionState::NotFound; - } - let versions = versions.unwrap(); + }; let versions = versions.value().read(); - let last_version = versions.last().unwrap(); + let last_version = versions.last().expect("versions should not be empty"); if last_version.is_visible_to(tx, &self.txs) { RowVersionState::LiveVersion } else { @@ -1550,7 +1563,10 @@ ) -> Option<RowID> { tracing::trace!("seek_rowid(bound={:?}, lower_bound={})", bound, lower_bound,); - let tx = self.txs.get(&tx_id).unwrap(); + let tx = self + .txs + .get(&tx_id) + .expect("transaction should exist in txs map"); let tx = tx.value(); let res = if lower_bound { self.rows @@ -1592,7 +1608,11 @@ let unlock = || self.blocking_checkpoint_lock.unlock(); let tx_id = maybe_existing_tx_id.unwrap_or_else(|| self.get_tx_id()); let begin_ts = if let Some(tx_id) = maybe_existing_tx_id { - self.txs.get(&tx_id).unwrap().value().begin_ts + self.txs + .get(&tx_id) + .ok_or(LimboError::NoSuchTransactionID(tx_id.to_string()))?
+ .value() + .begin_ts } else { self.get_timestamp() }; @@ -1672,11 +1692,14 @@ fn get_new_transaction_database_header(&self, pager: &Arc<Pager>) -> DatabaseHeader { if self.global_header.read().is_none() { - pager.io.block(|| pager.maybe_allocate_page1()).unwrap(); + pager + .io + .block(|| pager.maybe_allocate_page1()) + .expect("failed to allocate page1"); let header = pager .io .block(|| pager.with_header(|header| *header)) - .unwrap(); + .expect("failed to read database header"); // TODO: We initialize header here, maybe this needs more careful handling self.global_header.write().replace(header); tracing::debug!( @@ -1685,7 +1708,10 @@ ); header } else { - let header = self.global_header.read().unwrap(); + let header = self + .global_header + .read() + .expect("global_header should be initialized"); tracing::debug!("get_transaction_database_header read: header={:?}", header); header } } @@ -1707,7 +1733,10 @@ F: Fn(&DatabaseHeader) -> T, { if let Some(tx_id) = tx_id { - let tx = self.txs.get(tx_id).unwrap(); + let tx = self + .txs + .get(tx_id) + .ok_or(LimboError::NoSuchTransactionID(tx_id.to_string()))?; let header = tx.value(); let header = header.header.read(); tracing::debug!("with_header read: header={:?}", header); @ -1715,7 +1744,9 } else { let header = self.global_header.read(); tracing::debug!("with_header read: header={:?}", header); - Ok(f(header.as_ref().unwrap())) + Ok(f(header.as_ref().ok_or(LimboError::InternalError( + "global_header not initialized".to_string(), + ))?)) } } @@ -1724,14 +1755,19 @@ F: Fn(&mut DatabaseHeader) -> T, { if let Some(tx_id) = tx_id { - let tx = self.txs.get(tx_id).unwrap(); + let tx = self + .txs + .get(tx_id) + .ok_or(LimboError::NoSuchTransactionID(tx_id.to_string()))?; let header = tx.value(); let mut header = header.header.write(); tracing::debug!("with_header_mut read: header={:?}", header); Ok(f(&mut header)) } else { let mut header = self.global_header.write(); - let header = header.as_mut().unwrap(); + let header = header.as_mut().ok_or(LimboError::InternalError( + "global_header not initialized".to_string(), + ))?; tracing::debug!("with_header_mut write: header={:?}", header); Ok(f(header)) } } @@ -1767,7 +1803,10 @@ /// transactions. pub fn commit_load_tx(&self, tx_id: TxID) { let end_ts = self.get_timestamp(); - let tx = self.txs.get(&tx_id).unwrap(); + let tx = self + .txs + .get(&tx_id) + .expect("transaction should exist in txs map"); let tx = tx.value(); for rowid in &tx.write_set { let rowid = rowid.value(); @@ -1809,7 +1848,10 @@ /// /// * `tx_id` - The ID of the transaction to abort.
pub fn rollback_tx(&self, tx_id: TxID, _pager: Arc<Pager>, connection: &Connection) { - let tx_unlocked = self.txs.get(&tx_id).unwrap(); + let tx_unlocked = self + .txs + .get(&tx_id) + .expect("transaction should exist in txs map"); let tx = tx_unlocked.value(); *connection.mv_tx.write() = None; assert!(tx.state == TransactionState::Active || tx.state == TransactionState::Preparing); @@ -1986,7 +2028,13 @@ fn get_begin_timestamp(&self, ts_or_id: &Option<TxTimestampOrID>) -> u64 { match ts_or_id { Some(TxTimestampOrID::Timestamp(ts)) => *ts, - Some(TxTimestampOrID::TxID(tx_id)) => self.txs.get(tx_id).unwrap().value().begin_ts, + Some(TxTimestampOrID::TxID(tx_id)) => { + self.txs + .get(tx_id) + .expect("transaction should exist in txs map") + .value() + .begin_ts + } // This function is intended to be used in the ordering of row versions within the row version chain in `insert_version_raw`. // // The row version chain should be append-only (aside from garbage collection), @@ -2091,7 +2139,7 @@ let tx_id = 0; self.begin_load_tx(pager.clone())?; loop { - match reader.next_record(&pager.io).unwrap() { + match reader.next_record(&pager.io)? { StreamingResult::InsertRow { row, rowid } => { if rowid.table_id == SQLITE_SCHEMA_MVCC_TABLE_ID { // Sqlite schema row version inserts @@ -2099,7 +2147,9 @@ let record = ImmutableRecord::from_bin_record(row_data); let mut record_cursor = RecordCursor::new(); let mut record_values = record_cursor.get_values(&record); - let val = record_values.nth(3).unwrap()?; + let val = record_values.nth(3).ok_or(LimboError::InternalError( + "Expected at least 4 columns in sqlite_schema".to_string(), + ))??; let ValueRef::Integer(root_page) = val else { panic!("Expected integer value for root page, got {val:?}"); }; @@ -2181,7 +2231,9 @@ pub(crate) fn is_write_write_conflict( ) -> bool { match rv.end { Some(TxTimestampOrID::TxID(rv_end)) => { - let te = txs.get(&rv_end).unwrap(); + let te = txs + .get(&rv_end) + .expect("transaction should exist in txs map"); let te = te.value(); if te.tx_id == tx.tx_id { return false; } @@ -2207,7 +2259,9 @@ fn is_begin_visible(txs: &SkipMap<TxID, Transaction>, tx: &Transaction, rv: &Row match rv.begin { Some(TxTimestampOrID::Timestamp(rv_begin_ts)) => tx.begin_ts >= rv_begin_ts, Some(TxTimestampOrID::TxID(rv_begin)) => { - let tb = txs.get(&rv_begin).unwrap(); + let tb = txs + .get(&rv_begin) + .expect("transaction should exist in txs map"); let tb = tb.value(); let visible = match tb.state.load() { TransactionState::Active => tx.tx_id == tb.tx_id && rv.end.is_none(), @@ -2272,7 +2326,13 @@ fn stmt_get_all_rows(stmt: &mut Statement) -> Result<Vec<Vec<Value>>> { let step = stmt.step()?; match step { StepResult::Row => { - rows.push(stmt.row().unwrap().get_values().cloned().collect()); + rows.push( + stmt.row() + .ok_or(LimboError::InternalError("No row available".to_string()))?
+ .get_values() + .cloned() + .collect(), + ); } StepResult::Done => { break; diff --git a/core/mvcc/persistent_storage/logical_log.rs b/core/mvcc/persistent_storage/logical_log.rs index 5d9aff7071..c35b27333a 100644 --- a/core/mvcc/persistent_storage/logical_log.rs +++ b/core/mvcc/persistent_storage/logical_log.rs @@ -8,7 +8,8 @@ use crate::{ types::ImmutableRecord, Buffer, Completion, CompletionError, LimboError, Result, }; -use std::sync::{Arc, RwLock}; +use parking_lot::RwLock; +use std::sync::Arc; use crate::File; @@ -275,7 +276,7 @@ pub struct StreamingLogicalLogReader { impl StreamingLogicalLogReader { pub fn new(file: Arc<dyn File>) -> Self { - let file_size = file.size().unwrap() as usize; + let file_size = file.size().expect("failed to get file size") as usize; Self { file, offset: 0, @@ -292,7 +293,7 @@ let header = Arc::new(RwLock::new(LogHeader::default())); let completion: Box = Box::new(move |res| { let header = header.clone(); - let mut header = header.write().unwrap(); + let mut header = header.write(); let Ok((buf, bytes_read)) = res else { tracing::error!("couldn't ready log err={:?}", res,); return; }; @@ ... if ... { } let buf = buf.as_slice(); header.version = buf[0]; - header.salt = u64::from_be_bytes(buf[1..9].try_into().unwrap()); + header.salt = u64::from_be_bytes([ + buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], + ]); header.encrypted = buf[10]; tracing::trace!("LogicalLog header={:?}", header); }); @@ -400,36 +403,50 @@ fn consume_u8(&mut self, io: &Arc<dyn IO>) -> Result<u8> { self.read_more_data(io, 1)?; - let r = self.buffer.read().unwrap()[self.buffer_offset]; + let r = self.buffer.read()[self.buffer_offset]; self.buffer_offset += 1; Ok(r) } fn consume_i64(&mut self, io: &Arc<dyn IO>) -> Result<i64> { self.read_more_data(io, 8)?; - let r = i64::from_be_bytes( - self.buffer.read().unwrap()[self.buffer_offset..self.buffer_offset + 8] - .try_into() - .unwrap(), - ); + let buf = self.buffer.read(); + let offset = self.buffer_offset; + let r = i64::from_be_bytes([ + buf[offset], + buf[offset + 1], + buf[offset + 2], + buf[offset + 3], + buf[offset + 4], + buf[offset + 5], + buf[offset + 6], + buf[offset + 7], + ]); self.buffer_offset += 8; Ok(r) } fn consume_u64(&mut self, io: &Arc<dyn IO>) -> Result<u64> { self.read_more_data(io, 8)?; - let r = u64::from_be_bytes( - self.buffer.read().unwrap()[self.buffer_offset..self.buffer_offset + 8] - .try_into() - .unwrap(), - ); + let buf = self.buffer.read(); + let offset = self.buffer_offset; + let r = u64::from_be_bytes([ + buf[offset], + buf[offset + 1], + buf[offset + 2], + buf[offset + 3], + buf[offset + 4], + buf[offset + 5], + buf[offset + 6], + buf[offset + 7], + ]); self.buffer_offset += 8; Ok(r) } fn consume_varint(&mut self, io: &Arc<dyn IO>) -> Result<(u64, usize)> { self.read_more_data(io, 9)?; - let buffer_guard = self.buffer.read().unwrap(); + let buffer_guard = self.buffer.read(); let buffer = &buffer_guard[self.buffer_offset..]; let (v, n) = read_varint(buffer)?; self.buffer_offset += n; @:... } fn consume_buffer(&mut self, io: &Arc<dyn IO>, amount: usize) -> Result<Vec<u8>> { self.read_more_data(io, amount)?; - let buffer = - self.buffer.read().unwrap()[self.buffer_offset..self.buffer_offset + amount].to_vec(); + let buffer = self.buffer.read()[self.buffer_offset..self.buffer_offset + amount].to_vec(); self.buffer_offset += amount; Ok(buffer) } - fn get_buffer(&self) -> std::sync::RwLockReadGuard<'_, Vec<u8>> { -
self.buffer.read().unwrap() + fn get_buffer(&self) -> parking_lot::RwLockReadGuard<'_, Vec<u8>> { + self.buffer.read() } pub fn read_more_data(&mut self, io: &Arc<dyn IO>, need: usize) -> Result<()> { @@ -459,7 +475,7 @@ let buffer = self.buffer.clone(); let completion: Box = Box::new(move |res| { let buffer = buffer.clone(); - let mut buffer = buffer.write().unwrap(); + let mut buffer = buffer.write(); let Ok((buf, bytes_read)) = res else { tracing::trace!("couldn't ready log err={:?}", res,); return; }; @@ ... self.offset += to_read; // cleanup consumed bytes // this could be better for sure - let _ = self.buffer.write().unwrap().drain(0..self.buffer_offset); + let _ = self.buffer.write().drain(0..self.buffer_offset); self.buffer_offset = 0; Ok(()) } fn bytes_can_read(&self) -> usize { - self.buffer.read().unwrap().len() - self.buffer_offset + self.buffer.read().len() - self.buffer_offset } } diff --git a/core/schema.rs b/core/schema.rs index a503403724..f5708bdd71 100644 --- a/core/schema.rs +++ b/core/schema.rs @@ -632,7 +632,12 @@ } else { let table = self .get_btree_table(&unparsed_sql_from_index.table_name) - .unwrap() + .ok_or_else(|| { + LimboError::InternalError(format!( + "table {} not found", + unparsed_sql_from_index.table_name + )) + })?; let index = Index::from_sql( syms, &unparsed_sql_from_index.sql, @@ -652,12 +657,20 @@ // The SQL statement parser enforces that the column definitions come first, and compounds are defined after that, // e.g. CREATE TABLE t (a, b, UNIQUE(a, b)), and you can't do something like CREATE TABLE t (a, b, UNIQUE(a, b), c); // Hence, we can process the singles first (unique_set.columns.len() == 1), and then the compounds (unique_set.columns.len() > 1). - let table = self.get_btree_table(&automatic_index.0).unwrap(); + let table = self.get_btree_table(&automatic_index.0).ok_or_else(|| { + LimboError::InternalError(format!("table {} not found", automatic_index.0)) + })?; let mut automatic_indexes = automatic_index.1; automatic_indexes.reverse(); // reverse so we can pop() without shifting array elements, while still processing in left-to-right order let mut pk_index_added = false; for unique_set in table.unique_sets.iter().filter(|us| us.columns.len() == 1) { - let col_name = &unique_set.columns.first().unwrap().0; + let col_name = &unique_set + .columns + .first() + .ok_or_else(|| { + LimboError::InternalError("unique set has no columns".to_string()) + })?
+ .0; let Some((pos_in_table, column)) = table.get_column(col_name) else { return Err(LimboError::ParseError(format!( "Column {col_name} not found in table {}", @@ -669,7 +682,10 @@ impl Schema { // rowid alias, no index needed continue; } - assert!(table.primary_key_columns.first().unwrap().0 == *col_name, "trying to add a primary key index for column that is not the first column in the primary key: {} != {}", table.primary_key_columns.first().unwrap().0, col_name); + let first_pk_col = table.primary_key_columns.first().ok_or_else(|| { + LimboError::InternalError("table has no primary key columns".to_string()) + })?; + assert!(first_pk_col.0 == *col_name, "trying to add a primary key index for column that is not the first column in the primary key: {} != {}", first_pk_col.0, col_name); // Add single column primary key index assert!( !pk_index_added, @@ -677,18 +693,24 @@ impl Schema { table.name ); pk_index_added = true; + let root_page = automatic_indexes.pop().ok_or_else(|| { + LimboError::InternalError("not enough automatic indexes".to_string()) + })?; self.add_index(Arc::new(Index::automatic_from_primary_key( table.as_ref(), - automatic_indexes.pop().unwrap(), + root_page, 1, )?))?; } else { // Add single column unique index if let Some(autoidx) = automatic_indexes.pop() { + let first_col = unique_set.columns.first().ok_or_else(|| { + LimboError::InternalError("unique set has no columns".to_string()) + })?; self.add_index(Arc::new(Index::automatic_from_unique( table.as_ref(), autoidx, - vec![(pos_in_table, unique_set.columns.first().unwrap().1)], + vec![(pos_in_table, first_col.1)], )?))?; } } @@ -703,9 +725,12 @@ impl Schema { table.name ); pk_index_added = true; + let root_page = automatic_indexes.pop().ok_or_else(|| { + LimboError::InternalError("not enough automatic indexes".to_string()) + })?; self.add_index(Arc::new(Index::automatic_from_primary_key( table.as_ref(), - automatic_indexes.pop().unwrap(), + root_page, unique_set.columns.len(), )?))?; } else { @@ -721,9 +746,12 @@ impl Schema { }; column_indices_and_sort_orders.push((pos_in_table, *sort_order)); } + let root_page = automatic_indexes.pop().ok_or_else(|| { + LimboError::InternalError("not enough automatic indexes".to_string()) + })?; self.add_index(Arc::new(Index::automatic_from_unique( table.as_ref(), - automatic_indexes.pop().unwrap(), + root_page, column_indices_and_sort_orders, )?))?; } @@ -840,7 +868,10 @@ impl Schema { // Check if this is a DBSP state table if table.name.starts_with(DBSP_TABLE_PREFIX) { // Extract version and view name from __turso_internal_dbsp_state_v_ - let suffix = table.name.strip_prefix(DBSP_TABLE_PREFIX).unwrap(); + let suffix = table + .name + .strip_prefix(DBSP_TABLE_PREFIX) + .expect("checked starts_with above"); // Parse version and view name (format: "_") if let Some(underscore_pos) = suffix.find('_') { @@ -890,7 +921,9 @@ impl Schema { // Check if this is an index for a DBSP state table if table_name.starts_with(DBSP_TABLE_PREFIX) { // Extract version and view name from __turso_internal_dbsp_state_v_ - let suffix = table_name.strip_prefix(DBSP_TABLE_PREFIX).unwrap(); + let suffix = table_name + .strip_prefix(DBSP_TABLE_PREFIX) + .expect("checked starts_with above"); // Parse version and view name (format: "_") if let Some(underscore_pos) = suffix.find('_') { @@ -2049,7 +2082,7 @@ pub fn create_table(tbl_name: &str, body: &CreateTableBody, root_page: i64) -> R let unique_set_w_only_rowid_alias = unique_sets.iter().position(|us| { us.is_primary_key && us.columns.len() == 1 - && 
&us.columns.first().unwrap().0 == col.name.as_ref().unwrap() + && us.columns.first().map(|c| &c.0) == col.name.as_ref() }); if let Some(u) = unique_set_w_only_rowid_alias { unique_sets.remove(u); @@ -2221,7 +2254,9 @@ impl ResolvedFkRef { } // special case: if FK uses a rowid alias on child, and rowid changed if self.child_cols.len() == 1 { - let (i, col) = child_tbl.get_column(&self.child_cols[0]).unwrap(); + let (i, col) = child_tbl + .get_column(&self.child_cols[0]) + .expect("child_cols[0] exists in child table"); if col.is_rowid_alias() && updated_child_positions.contains(&i) { return true; } @@ -2662,7 +2697,9 @@ impl Index { col_name, table.name ))); }; - let (_, column) = table.get_column(col_name).unwrap(); + let (_, column) = table + .get_column(col_name) + .expect("checked column exists above"); primary_keys.push(IndexColumn { name: normalize_ident(col_name), order: *order, @@ -2704,7 +2741,7 @@ impl Index { .iter() .find(|(pos, _)| *pos == pos_in_table)?; Some(IndexColumn { - name: normalize_ident(col.name.as_ref().unwrap()), + name: normalize_ident(col.name.as_ref().expect("column has a name")), order: *sort_order, pos_in_table: *pos_in_table, collation: col.collation_opt(), diff --git a/core/storage/btree.rs b/core/storage/btree.rs index 7baedd4d25..0eb9d6f051 100644 --- a/core/storage/btree.rs +++ b/core/storage/btree.rs @@ -723,12 +723,9 @@ impl BTreeCursor { } let mut record_cursor_ref = self.record_cursor.borrow_mut(); let record_cursor = record_cursor_ref.deref_mut(); - let rowid = match self - .get_immutable_record() - .as_ref() - .unwrap() - .last_value(record_cursor) - { + let immutable_record_opt = self.get_immutable_record(); + let immutable_record = immutable_record_opt.as_ref()?; + let rowid = match immutable_record.last_value(record_cursor) { Some(Ok(ValueRef::Integer(rowid))) => rowid, _ => unreachable!( "index where has_rowid() is true should have an integer rowid as the last value" @@ -844,7 +841,7 @@ impl BTreeCursor { let left_child_page = self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| LimboError::Corrupt("Page not found at level in stack".to_string()))? 
.cell_interior_read_left_child_page(cell_idx); if page_type == PageType::IndexInterior { @@ -895,7 +892,9 @@ impl BTreeCursor { next_page, remaining_to_read, page, - } = read_overflow_state.as_mut().unwrap(); + } = read_overflow_state.as_mut().ok_or_else(|| { + LimboError::Corrupt("overflow read state unexpectedly None".into()) + })?; turso_assert!(page.is_loaded(), "page should be loaded"); tracing::debug!(next_page, remaining_to_read, "reading overflow page"); @@ -925,12 +924,11 @@ impl BTreeCursor { std::mem::swap(payload, &mut payload_swap); let mut reuse_immutable = self.get_immutable_record_or_create(); - reuse_immutable.as_mut().unwrap().invalidate(); - - reuse_immutable - .as_mut() - .unwrap() - .start_serialization(&payload_swap); + let reuse_immutable_ref = reuse_immutable.as_mut().ok_or_else(|| { + LimboError::Corrupt("immutable record unexpectedly None after create".into()) + })?; + reuse_immutable_ref.invalidate(); + reuse_immutable_ref.start_serialization(&payload_swap); self.record_cursor.borrow_mut().invalidate(); let _ = read_overflow_state.take(); @@ -1018,7 +1016,7 @@ impl BTreeCursor { } let usable_size = self.usable_space(); - let cell = contents.cell_get(cell_idx, usable_size).unwrap(); + let cell = contents.cell_get(cell_idx, usable_size)?; let (payload, payload_size, first_overflow_page) = match cell { BTreeCell::TableLeafCell(cell) => { @@ -1068,17 +1066,15 @@ impl BTreeCursor { } if amount > 0 { - if first_overflow_page.is_none() { - return Err(LimboError::Corrupt( - "Expected overflow page but none found".into(), - )); - } + let first_overflow_page = first_overflow_page.ok_or_else(|| { + LimboError::Corrupt("Expected overflow page but none found".into()) + })?; let overflow_size = usable_size - 4; let pages_to_skip = offset / overflow_size as u32; let page_offset = offset % overflow_size as u32; // Read page - let (page, c) = self.read_page(first_overflow_page.unwrap() as i64)?; + let (page, c) = self.read_page(first_overflow_page as i64)?; self.state = CursorState::ReadWritePayload(PayloadOverflowWithOffset::SkipOverflowPages { @@ -1523,7 +1519,9 @@ impl BTreeCursor { let left_child_page = self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found at level in stack".to_string()) + })? .cell_interior_read_left_child_page(nearest_matching_cell); self.stack.set_cell_index(nearest_matching_cell as i32); let (mem_page, c) = self.read_page(left_child_page as i64)?; @@ -1540,7 +1538,9 @@ impl BTreeCursor { match self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found at level in stack".to_string()) + })? .rightmost_pointer() { Some(right_most_pointer) => { @@ -1563,7 +1563,7 @@ impl BTreeCursor { let cell_rowid = self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| LimboError::Corrupt("Page not found at level in stack".to_string()))? .cell_table_interior_read_rowid(cur_cell_idx as usize)?; // in sqlite btrees left child pages have <= keys. // table btrees can have a duplicate rowid in the interior cell, so for example if we are looking for rowid=10, @@ -1686,7 +1686,9 @@ impl BTreeCursor { match self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found at level in stack".to_string()) + })? 
.rightmost_pointer() { Some(right_most_pointer) => { @@ -1708,7 +1710,9 @@ impl BTreeCursor { let matching_cell = self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found at level in stack".to_string()) + })? .cell_get(leftmost_matching_cell, self.usable_space())?; self.stack.set_cell_index(leftmost_matching_cell as i32); // we don't advance in case of forward iteration and index tree internal nodes because we will visit this node going up. @@ -1728,7 +1732,9 @@ impl BTreeCursor { }; { - let page = self.stack.get_page_at_level(old_top_idx).unwrap(); + let page = self.stack.get_page_at_level(old_top_idx).ok_or_else(|| { + LimboError::Corrupt("Page not found at level in stack".to_string()) + })?; turso_assert!( page.get().id != *left_child_page as usize, "corrupt: current page and left child page of cell {} are both {}", @@ -1753,7 +1759,7 @@ impl BTreeCursor { let cell = self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| LimboError::Corrupt("Page not found at level in stack".to_string()))? .cell_get(cur_cell_idx as usize, self.usable_space())?; let BTreeCell::IndexInteriorCell(IndexInteriorCell { payload, @@ -1770,29 +1776,33 @@ impl BTreeCursor { } else { self.get_immutable_record_or_create() .as_mut() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Failed to get or create immutable record".to_string()) + })? .invalidate(); self.get_immutable_record_or_create() .as_mut() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Failed to get or create immutable record".to_string()) + })? .start_serialization(payload); self.record_cursor.borrow_mut().invalidate(); }; let (target_leaf_page_is_in_left_subtree, is_eq) = { let record = self.get_immutable_record(); - let record = record.as_ref().unwrap(); - - let interior_cell_vs_index_key = record_comparer - .compare( - record, - &key_values, - self.index_info - .as_ref() - .expect("indexbtree_move_to without index_info"), - 0, - tie_breaker, - ) - .unwrap(); + let record = record.as_ref().ok_or_else(|| { + LimboError::Corrupt("Immutable record not available".to_string()) + })?; + + let interior_cell_vs_index_key = record_comparer.compare( + record, + &key_values, + self.index_info + .as_ref() + .expect("indexbtree_move_to without index_info"), + 0, + tie_breaker, + )?; // in sqlite btrees left child pages have <= keys. // in general, in forwards iteration we want to find the first key that matches the seek condition. @@ -2087,7 +2097,11 @@ impl BTreeCursor { < self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt( + "Page not found at level in stack".to_string(), + ) + })? .cell_count() as i32; self.has_record.set(has_record); @@ -2107,7 +2121,7 @@ impl BTreeCursor { let cell = self .stack .get_page_contents_at_level(old_top_idx) - .unwrap() + .ok_or_else(|| LimboError::Corrupt("Page not found at level in stack".to_string()))? .cell_get(cur_cell_idx as usize, self.usable_space())?; let BTreeCell::IndexLeafCell(IndexLeafCell { payload, @@ -2123,11 +2137,15 @@ impl BTreeCursor { } else { self.get_immutable_record_or_create() .as_mut() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Failed to get or create immutable record".to_string()) + })? .invalidate(); self.get_immutable_record_or_create() .as_mut() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Failed to get or create immutable record".to_string()) + })? 
.start_serialization(payload); self.record_cursor.borrow_mut().invalidate(); @@ -2139,7 +2157,7 @@ impl BTreeCursor { self.index_info .as_ref() .expect("indexbtree_seek without index_info"), - ); + )?; if found { nearest_matching_cell.set(Some(cur_cell_idx as usize)); match iter_dir { @@ -2181,14 +2199,14 @@ impl BTreeCursor { seek_op: SeekOp, record_comparer: &RecordCompare, index_info: &IndexInfo, - ) -> (Ordering, bool) { + ) -> Result<(Ordering, bool)> { let record = self.get_immutable_record(); - let record = record.as_ref().unwrap(); + let record = record + .as_ref() + .ok_or_else(|| LimboError::Corrupt("Immutable record not available".to_string()))?; let tie_breaker = get_tie_breaker_from_seek_op(seek_op); - let cmp = record_comparer - .compare(record, key_values, index_info, 0, tie_breaker) - .unwrap(); + let cmp = record_comparer.compare(record, key_values, index_info, 0, tie_breaker)?; let found = match seek_op { SeekOp::GT => cmp.is_gt(), @@ -2198,7 +2216,7 @@ impl BTreeCursor { SeekOp::LE { eq_only: false } => cmp.is_le(), SeekOp::LT => cmp.is_lt(), }; - (cmp, found) + Ok((cmp, found)) } #[instrument(skip_all, level = Level::DEBUG)] @@ -2315,15 +2333,15 @@ impl BTreeCursor { record_values.as_slice(), self.get_immutable_record() .as_ref() - .unwrap() + .ok_or_else(|| LimboError::Corrupt("Immutable record not available".to_string()))? .get_values().as_slice(), - &self.index_info.as_ref().unwrap().key_info, + &self.index_info.as_ref().ok_or_else(|| LimboError::Corrupt("Index info not available".to_string()))?.key_info, ); if cmp == Ordering::Equal { tracing::debug!("IndexLeafCell: found exact match with cell_idx={cell_idx}, overwriting"); self.has_record.set(true); let CursorState::Write(write_state) = &mut self.state else { - panic!("expected write state"); + return Err(LimboError::Corrupt("Expected write state".to_string())); }; *write_state = WriteState::Overwrite { page, @@ -2534,18 +2552,28 @@ impl BTreeCursor { if cur_page_contents.page_type() == PageType::TableLeaf && cur_page_contents.overflow_cells.len() == 1 { - let overflow_cell_is_last = - cur_page_contents.overflow_cells.first().unwrap().index - == cur_page_contents.cell_count(); + let overflow_cell_is_last = cur_page_contents + .overflow_cells + .first() + .ok_or_else(|| { + LimboError::Corrupt("Overflow cells list is empty".to_string()) + })? + .index + == cur_page_contents.cell_count(); if overflow_cell_is_last { let parent = self .stack .get_page_at_level(self.stack.current() - 1) - .expect("parent page should be on the stack"); + .ok_or_else(|| { + LimboError::Corrupt( + "Parent page not found on stack".to_string(), + ) + })?; let parent_contents = parent.get_contents(); if parent.get().id != 1 - && parent_contents.rightmost_pointer().unwrap() - == cur_page.get().id as u32 + && parent_contents.rightmost_pointer().ok_or_else(|| { + LimboError::Corrupt("Rightmost pointer not found".to_string()) + })? 
== cur_page.get().id as u32 { // If all of the following are true, we can use the balance_quick() fast path: // - The page is a table leaf page @@ -2711,7 +2739,12 @@ impl BTreeCursor { self.stack.retreat(); } - let parent_page = self.stack.get_page_at_level(parent_page_idx).unwrap(); + let parent_page = + self.stack + .get_page_at_level(parent_page_idx) + .ok_or_else(|| { + LimboError::Corrupt("Parent page not found on stack".to_string()) + })?; let parent_contents = parent_page.get_contents(); if !past_rightmost_pointer && over_cell_count > 0 { // The ONLY way we can have an overflow cell in the parent is if we replaced an interior cell from a cell in the child, and that replacement did not fit. @@ -2721,7 +2754,10 @@ impl BTreeCursor { } else { turso_assert!(false, "{page_type:?} must have no overflow cells"); } - let overflow_cell = parent_contents.overflow_cells.first().unwrap(); + let overflow_cell = + parent_contents.overflow_cells.first().ok_or_else(|| { + LimboError::Corrupt("Overflow cell not found in parent".to_string()) + })?; let parent_page_cell_idx = self.stack.current_cell_index() as usize; // Parent page must be positioned at the divider cell that overflowed due to the replacement. turso_assert!( @@ -2777,7 +2813,9 @@ impl BTreeCursor { == parent_contents.cell_count(); // Get the right page pointer that we will need to update later let right_pointer = if last_sibling_is_right_pointer { - parent_contents.rightmost_pointer_raw().unwrap() + parent_contents.rightmost_pointer_raw().ok_or_else(|| { + LimboError::Corrupt("Rightmost pointer not found".to_string()) + })? } else { let max_overflow_cells = if matches!(page_type, PageType::IndexInterior) { 1 @@ -2857,9 +2895,19 @@ impl BTreeCursor { parent_contents.overflow_cells.len() == 1, "must have a single overflow cell in the parent, as a result of InteriorNodeReplacement" ); - let overflow_cell = parent_contents.overflow_cells.first().unwrap(); - pgno = - u32::from_be_bytes(overflow_cell.payload[0..4].try_into().unwrap()); + let overflow_cell = + parent_contents.overflow_cells.first().ok_or_else(|| { + LimboError::Corrupt( + "Overflow cell not found in parent".to_string(), + ) + })?; + pgno = u32::from_be_bytes( + overflow_cell.payload[0..4].try_into().map_err(|_| { + LimboError::Corrupt( + "Invalid page number in overflow cell".to_string(), + ) + })?, + ); } else { // grep for 'OVERFLOW CELL ADJUSTMENT' for explanation. // here we only subtract 1 if the divider cell has been shifted left, i.e. the overflow cell was placed to the left @@ -2910,19 +2958,25 @@ impl BTreeCursor { BalanceSubState::NonRootDoBalancing => { // Ensure all involved pages are in memory. let mut balance_info = balance_info.borrow_mut(); - let balance_info = balance_info.as_mut().unwrap(); + let balance_info = balance_info.as_mut().ok_or_else(|| { + LimboError::Corrupt("Balance info not available".to_string()) + })?; for page in balance_info .pages_to_balance .iter() .take(balance_info.sibling_count) { - let page = page.as_ref().unwrap(); + let page = page + .as_ref() + .expect("Page should exist in balance info during validation"); self.pager.add_dirty(page)?; #[cfg(debug_assertions)] let page_type_of_siblings = balance_info.pages_to_balance[0] .as_ref() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })? 
.get_contents() .page_type(); @@ -2945,7 +2999,10 @@ impl BTreeCursor { MAX_NEW_SIBLING_PAGES_AFTER_BALANCE] = [const { None }; MAX_NEW_SIBLING_PAGES_AFTER_BALANCE]; for i in (0..balance_info.sibling_count).rev() { - let sibling_page = balance_info.pages_to_balance[i].as_ref().unwrap(); + let sibling_page = + balance_info.pages_to_balance[i].as_ref().ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; turso_assert!(sibling_page.is_loaded(), "sibling page is not loaded"); let sibling_contents = sibling_page.get_contents(); total_cells_to_redistribute += sibling_contents.cell_count(); @@ -2973,7 +3030,12 @@ impl BTreeCursor { parent_contents.overflow_cells.len() == 1, "must have a single overflow cell in the parent, as a result of InteriorNodeReplacement" ); - let overflow_cell = parent_contents.overflow_cells.first().unwrap(); + let overflow_cell = + parent_contents.overflow_cells.first().ok_or_else(|| { + LimboError::Corrupt( + "Overflow cell not found in parent".to_string(), + ) + })?; &overflow_cell.payload } else { // grep for 'OVERFLOW CELL ADJUSTMENT' for explanation. @@ -3034,7 +3096,9 @@ impl BTreeCursor { let page_type = balance_info.pages_to_balance[0] .as_ref() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })? .get_contents() .page_type(); tracing::debug!("balance_non_root(page_type={:?})", page_type); @@ -3046,7 +3110,9 @@ impl BTreeCursor { .take(balance_info.sibling_count) .enumerate() { - let old_page = old_page.as_ref().unwrap(); + let old_page = old_page.as_ref().ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; let old_page_contents = old_page.get_contents(); let page_type = old_page_contents.page_type(); let max_local = payload_overflow_threshold_max(page_type, usable_space); @@ -3062,7 +3128,7 @@ impl BTreeCursor { max_local, min_local, page_type, - ); + )?; let buf = old_page_contents.as_ptr(); let cell_buf = &mut buf[cell_start..cell_start + cell_len]; // TODO(pere): make this reference and not copy @@ -3089,14 +3155,23 @@ impl BTreeCursor { // But we don't need the last divider as it will remain the same. let mut divider_cell = balance_info.divider_cell_payloads[i] .as_mut() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt( + "Divider cell payload not available".to_string(), + ) + })? .as_mut_slice(); // TODO(pere): in case of old pages are leaf pages, so index leaf page, we need to strip page pointers // from divider cells in index interior pages (parent) because those should not be included. 
cells_inserted += 1; if !is_leaf { // This divider cell needs to be updated with new left pointer, - let right_pointer = old_page_contents.rightmost_pointer().unwrap(); + let right_pointer = + old_page_contents.rightmost_pointer().ok_or_else(|| { + LimboError::Corrupt( + "Rightmost pointer not found".to_string(), + ) + })?; divider_cell[..LEFT_CHILD_PTR_SIZE_BYTES] .copy_from_slice(&right_pointer.to_be_bytes()); } else { @@ -3148,7 +3223,9 @@ impl BTreeCursor { for i in 0..balance_info.sibling_count { cell_array.cell_count_per_page_cumulative[i] = old_cell_count_per_page_cumulative[i]; - let page = &balance_info.pages_to_balance[i].as_ref().unwrap(); + let page = &balance_info.pages_to_balance[i].as_ref().ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; let page_contents = page.get_contents(); let free_space = compute_free_space(page_contents, usable_space); @@ -3409,18 +3486,26 @@ impl BTreeCursor { cell_array, sibling_count_new, .. - } = context.as_mut().unwrap(); + } = context.as_mut().ok_or_else(|| { + LimboError::Corrupt("Balance context not available".to_string()) + })?; let pager = self.pager.clone(); let mut balance_info = balance_info.borrow_mut(); - let balance_info = balance_info.as_mut().unwrap(); + let balance_info = balance_info.as_mut().ok_or_else(|| { + LimboError::Corrupt("Balance info not available".to_string()) + })?; let page_type = balance_info.pages_to_balance[0] .as_ref() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })? .get_contents() .page_type(); // Allocate pages or set dirty if not needed if *i < balance_info.sibling_count { - let page = balance_info.pages_to_balance[*i].as_ref().unwrap(); + let page = balance_info.pages_to_balance[*i].as_ref().ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; turso_assert!(page.is_dirty(), "sibling page must be already marked dirty"); pages_to_balance_new[*i].replace(page.clone()); } else { @@ -3441,7 +3526,9 @@ impl BTreeCursor { continue; } else { *sub_state = BalanceSubState::NonRootDoBalancingFinish { - context: context.take().unwrap(), + context: context.take().ok_or_else(|| { + LimboError::Corrupt("Balance context not available".to_string()) + })?, }; } } @@ -3457,10 +3544,14 @@ impl BTreeCursor { }, } => { let mut balance_info = balance_info.borrow_mut(); - let balance_info = balance_info.as_mut().unwrap(); + let balance_info = balance_info.as_mut().ok_or_else(|| { + LimboError::Corrupt("Balance info not available".to_string()) + })?; let page_type = balance_info.pages_to_balance[0] .as_ref() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })? .get_contents() .page_type(); let parent_is_root = !self.stack.has_parent(); @@ -3477,7 +3568,15 @@ impl BTreeCursor { .take(sibling_count_new) .enumerate() { - page_numbers[i] = page.as_ref().unwrap().get().id; + page_numbers[i] = page + .as_ref() + .ok_or_else(|| { + LimboError::Corrupt( + "Page not found in balance info".to_string(), + ) + })? 
+ .get() + .id; } page_numbers.sort(); for (page, new_id) in pages_to_balance_new @@ -3486,7 +3585,9 @@ .iter() .rev() .zip(page_numbers.iter().rev().take(sibling_count_new)) { - let page = page.as_ref().unwrap(); + let page = page + .as_ref() + .expect("Page should exist in balance info during validation"); if *new_id != page.get().id { page.get().id = *new_id; self.pager @@ -3503,7 +3604,12 @@ impl BTreeCursor { for page in pages_to_balance_new.iter().take(sibling_count_new) { tracing::debug!( "balance_non_root(new_sibling page_id={})", - page.as_ref().unwrap().get().id + page.as_ref() + .ok_or_else(|| LimboError::Corrupt( + "Page not found in balance info".to_string() + ))? + .get() + .id ); } } @@ -3519,7 +3625,9 @@ impl BTreeCursor { // therefore invalidating the pointer. let right_page_id = pages_to_balance_new[sibling_count_new - 1] .as_ref() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })? .get() .id as u32; let rightmost_pointer = balance_info.rightmost_pointer; @@ -3543,11 +3651,20 @@ impl BTreeCursor { let last_sibling_idx = balance_info.sibling_count - 1; let last_page = balance_info.pages_to_balance[last_sibling_idx] .as_ref() - .unwrap(); - let right_pointer = last_page.get_contents().rightmost_pointer().unwrap(); + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; + let right_pointer = last_page + .get_contents() + .rightmost_pointer() + .ok_or_else(|| { + LimboError::Corrupt("Rightmost pointer not found".to_string()) + })?; let new_last_page = pages_to_balance_new[sibling_count_new - 1] .as_ref() - .unwrap(); + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; new_last_page .get_contents() .write_rightmost_ptr(right_pointer); @@ -3564,7 +3681,9 @@ impl BTreeCursor { .take(sibling_count_new - 1) /* do not take last page */ { - let page = page.as_ref().unwrap(); + let page = page + .as_ref() + .expect("Page should exist in balance info during validation"); // e.g. if we have 3 pages and the leftmost child page has 3 cells, // then the divider cell idx is 3 in the flat cell array. let divider_cell_idx = cell_array.cell_count_up_to_page(sibling_page_idx); @@ -3675,7 +3794,9 @@ impl BTreeCursor { { // Let's ensure every page is pointed to by the divider cell or the rightmost pointer.
for page in pages_to_balance_new.iter().take(sibling_count_new) { - let page = page.as_ref().unwrap(); + let page = page + .as_ref() + .expect("Page should exist in balance info during validation"); assert!( pages_pointed_to.contains(&(page.get().id as u32)), "page {} not pointed to by divider cell or rightmost pointer", @@ -3753,7 +3874,12 @@ impl BTreeCursor { cell_array.cell_count_up_to_page(page_idx) - start_new_cells, ) }; - let page = pages_to_balance_new[page_idx].as_ref().unwrap(); + let page = + pages_to_balance_new[page_idx].as_ref().ok_or_else(|| { + LimboError::Corrupt( + "Page not found in balance info".to_string(), + ) + })?; tracing::debug!("pre_edit_page(page={})", page.get().id); let page_contents = page.get_contents(); edit_page( @@ -3777,7 +3903,9 @@ impl BTreeCursor { } // TODO: vacuum support - let first_child_page = pages_to_balance_new[0].as_ref().unwrap(); + let first_child_page = pages_to_balance_new[0].as_ref().ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; let first_child_contents = first_child_page.get_contents(); if parent_is_root && parent_contents.cell_count() == 0 @@ -3841,7 +3969,7 @@ impl BTreeCursor { sibling_count_new, right_page_id, usable_space, - ); + )?; // Balance-shallower case if sibling_count_new == 0 { @@ -3872,7 +4000,11 @@ impl BTreeCursor { } else { let balance_info = balance_info.borrow(); let balance_info = balance_info.as_ref().expect("must be balancing"); - let page = balance_info.pages_to_balance[*curr_page].as_ref().unwrap(); + let page = balance_info.pages_to_balance[*curr_page] + .as_ref() + .ok_or_else(|| { + LimboError::Corrupt("Page not found in balance info".to_string()) + })?; return_if_io!(self.pager.free_page(Some(page.clone()), page.get().id)); *sub_state = BalanceSubState::FreePages { curr_page: *curr_page + 1, @@ -3947,11 +4079,11 @@ impl BTreeCursor { sibling_count_new: usize, right_page_id: u32, usable_space: usize, - ) { + ) -> Result<()> { let mut valid = true; let mut current_index_cell = 0; for cell_idx in 0..parent_contents.cell_count() { - let cell = parent_contents.cell_get(cell_idx, usable_space).unwrap(); + let cell = parent_contents.cell_get(cell_idx, usable_space)?; match cell { BTreeCell::TableInteriorCell(table_interior_cell) => { let left_child_page = table_interior_cell.left_child_page; @@ -3982,7 +4114,11 @@ impl BTreeCursor { .take(sibling_count_new) .enumerate() { - let page = page.as_ref().unwrap(); + let page = page.as_ref().ok_or_else(|| { + LimboError::Corrupt( + "Page should exist in balance info during validation".to_string(), + ) + })?; let contents = page.get_contents(); debug_validate_cells!(contents, usable_space); // Cells are distributed in order @@ -4004,8 +4140,7 @@ impl BTreeCursor { contents, 0, usable_space, - ) - .unwrap(); + )?; match &cell { BTreeCell::TableInteriorCell(table_interior_cell) => { let left_child_page = table_interior_cell.left_child_page; @@ -4223,11 +4358,8 @@ impl BTreeCursor { contents, 0, usable_space, - ) - .unwrap(); - let parent_cell = parent_contents - .cell_get(cell_divider_idx, usable_space) - .unwrap(); + )?; + let parent_cell = parent_contents.cell_get(cell_divider_idx, usable_space)?; let rowid = match cell { BTreeCell::TableLeafCell(table_leaf_cell) => table_leaf_cell.rowid, _ => unreachable!(), @@ -4324,6 +4456,7 @@ impl BTreeCursor { valid, "corrupted database, cells were not balanced properly" ); + Ok(()) } /// Balance the root page. 
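The b-tree hunks above repeat one conversion many times: an `Option` that was previously `unwrap()`ed is turned into a recoverable error with `ok_or_else` and propagated with `?`, since a missing page mid-balance is a symptom of database corruption rather than a broken programming invariant. A minimal sketch of the pattern, using simplified stand-ins for the crate's `LimboError` and page types (the real definitions differ):

```rust
use std::sync::Arc;

#[derive(Debug)]
enum LimboError {
    Corrupt(String),
}

type Result<T> = std::result::Result<T, LimboError>;

struct Page {
    id: u32,
}

// A missing sibling page during balancing indicates corruption, so instead
// of panicking via `unwrap()`, surface an error the caller can handle.
fn sibling_page(pages: &[Option<Arc<Page>>], i: usize) -> Result<Arc<Page>> {
    pages[i]
        .as_ref()
        .cloned()
        .ok_or_else(|| LimboError::Corrupt("Page not found in balance info".to_string()))
}

fn main() -> Result<()> {
    let pages = vec![Some(Arc::new(Page { id: 2 })), None];
    assert_eq!(sibling_page(&pages, 0)?.id, 2);
    assert!(sibling_page(&pages, 1).is_err());
    Ok(())
}
```

`ok_or_else` rather than `ok_or` keeps the `String` allocation off the happy path; the `expect` calls that remain nearby mark conditions the surrounding code guarantees (for instance, pages validated a few lines earlier).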
@@ -4703,7 +4836,10 @@ impl BTreeCursor { fn clear_root(&mut self, root_page: &PageRef) -> Result<()> { let page_ref = root_page.get(); - let contents = page_ref.contents.as_ref().unwrap(); + let contents = page_ref + .contents + .as_ref() + .ok_or_else(|| LimboError::Corrupt("Page contents not loaded".to_string()))?; let page_type = match contents.page_type() { PageType::TableLeaf | PageType::TableInterior => PageType::TableLeaf, @@ -4844,7 +4980,10 @@ impl BTreeCursor { self.valid_state = CursorValidState::Valid; return Ok(IOResult::Done(())); } - let ctx = self.context.take().unwrap(); + let ctx = self + .context + .take() + .ok_or_else(|| LimboError::Corrupt("Cursor context not available".to_string()))?; let seek_key = match ctx.key { CursorContextKey::TableRowId(rowid) => SeekKey::TableRowId(rowid), CursorContextKey::IndexKeyRowId(ref record) => SeekKey::IndexKey(record), @@ -4995,7 +5134,7 @@ impl CursorTrait for BTreeCursor { if !invalidated { let record_ref = Ref::filter_map(self.reusable_immutable_record.borrow(), |opt| opt.as_ref()) - .unwrap(); + .expect("Record should be valid after checking invalidated flag"); return Ok(IOResult::Done(Some(record_ref))); } @@ -5028,17 +5167,26 @@ impl CursorTrait for BTreeCursor { } else { self.get_immutable_record_or_create() .as_mut() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt( + "Failed to get mutable reference to immutable record".to_string(), + ) + })? .invalidate(); self.get_immutable_record_or_create() .as_mut() - .unwrap() + .ok_or_else(|| { + LimboError::Corrupt( + "Failed to get mutable reference to immutable record".to_string(), + ) + })? .start_serialization(payload); self.record_cursor.borrow_mut().invalidate(); }; let record_ref = - Ref::filter_map(self.reusable_immutable_record.borrow(), |opt| opt.as_ref()).unwrap(); + Ref::filter_map(self.reusable_immutable_record.borrow(), |opt| opt.as_ref()) + .expect("Record should be valid after get_immutable_record_or_create"); Ok(IOResult::Done(Some(record_ref))) } @@ -5538,8 +5686,9 @@ impl CursorTrait for BTreeCursor { if cell_idx == contents.cell_count() { // Move to right child - // should be safe as contents is not a leaf page - let right_most_pointer = contents.rightmost_pointer().unwrap(); + let right_most_pointer = contents.rightmost_pointer().ok_or_else(|| { + LimboError::Corrupt("Interior page missing rightmost pointer".into()) + })?; self.stack.advance(); let (mem_page, c) = self.read_page(right_most_pointer as i64)?; self.stack.push(mem_page); @@ -5620,7 +5769,7 @@ impl CursorTrait for BTreeCursor { fn invalidate_record(&mut self) { self.get_immutable_record_or_create() .as_mut() - .unwrap() + .expect("Failed to get mutable reference to immutable record after creation") .invalidate(); self.record_cursor.borrow_mut().invalidate(); } @@ -5645,7 +5794,9 @@ impl CursorTrait for BTreeCursor { } fn get_index_info(&self) -> &IndexInfo { - self.index_info.as_ref().unwrap() + self.index_info + .as_ref() + .expect("Index info should be set for index cursors") } fn seek_end(&mut self) -> Result> { @@ -6346,7 +6497,9 @@ impl PageStack { return; } let current = self.current(); - let page = self.stack[current].as_ref().unwrap(); + let page = self.stack[current] + .as_ref() + .expect("Page should exist in stack at current position"); turso_assert!( page.is_pinned(), "parent page {} is not pinned", @@ -6394,14 +6547,18 @@ impl PageStack { /// This is the page that is currently being traversed. 
fn top(&self) -> Arc { let current = self.current(); - let page = self.stack[current].clone().unwrap(); + let page = self.stack[current] + .clone() + .expect("Page should exist in stack at current position"); turso_assert!(page.is_loaded(), "page should be loaded"); page } fn top_ref(&self) -> &Arc { let current = self.current(); - let page = self.stack[current].as_ref().unwrap(); + let page = self.stack[current] + .as_ref() + .expect("Page should exist in stack at current position"); turso_assert!(page.is_loaded(), "page should be loaded"); page } @@ -7584,7 +7741,9 @@ fn fill_cell_payload( cell_payload.extend_from_slice(&left_child_page.to_be_bytes()); } if matches!(page_type, PageType::TableLeaf) { - let int_key = int_key.unwrap(); + let int_key = int_key.ok_or_else(|| { + LimboError::Corrupt("int_key must be provided for TableLeaf pages".into()) + })?; write_varint_to_vec(record_buf.len() as u64, cell_payload); write_varint_to_vec(int_key as u64, cell_payload); } else { diff --git a/core/storage/checksum.rs b/core/storage/checksum.rs index 43bd0c5be4..4e21d81022 100644 --- a/core/storage/checksum.rs +++ b/core/storage/checksum.rs @@ -59,7 +59,16 @@ impl ChecksumContext { let actual_page = &page[..CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE]; let stored_checksum_bytes = &page[CHECKSUM_PAGE_SIZE - CHECKSUM_SIZE..]; - let stored_checksum = u64::from_le_bytes(stored_checksum_bytes.try_into().unwrap()); + let stored_checksum = u64::from_le_bytes([ + stored_checksum_bytes[0], + stored_checksum_bytes[1], + stored_checksum_bytes[2], + stored_checksum_bytes[3], + stored_checksum_bytes[4], + stored_checksum_bytes[5], + stored_checksum_bytes[6], + stored_checksum_bytes[7], + ]); let computed_checksum = self.compute_checksum(actual_page); if stored_checksum != computed_checksum { diff --git a/core/storage/database.rs b/core/storage/database.rs index 3e1db81079..7f6ea31968 100644 --- a/core/storage/database.rs +++ b/core/storage/database.rs @@ -204,8 +204,8 @@ impl DatabaseStorage for DatabaseFile { return Err(LimboError::IntegerOverflow); }; let buffer = match &io_ctx.encryption_or_checksum { - EncryptionOrChecksum::Encryption(ctx) => encrypt_buffer(page_idx, buffer, ctx), - EncryptionOrChecksum::Checksum(ctx) => checksum_buffer(page_idx, buffer, ctx), + EncryptionOrChecksum::Encryption(ctx) => encrypt_buffer(page_idx, buffer, ctx)?, + EncryptionOrChecksum::Checksum(ctx) => checksum_buffer(page_idx, buffer, ctx)?, EncryptionOrChecksum::None => buffer, }; self.file.pwrite(pos, buffer, c) @@ -232,12 +232,12 @@ impl DatabaseStorage for DatabaseFile { .into_iter() .enumerate() .map(|(i, buffer)| encrypt_buffer(first_page_idx + i, buffer, ctx)) - .collect::>(), + .collect::>>()?, EncryptionOrChecksum::Checksum(ctx) => buffers .into_iter() .enumerate() .map(|(i, buffer)| checksum_buffer(first_page_idx + i, buffer, ctx)) - .collect::>(), + .collect::>>()?, EncryptionOrChecksum::None => buffers, }; let c = self.file.pwritev(pos, buffers, c)?; @@ -268,13 +268,20 @@ impl DatabaseFile { } } -fn encrypt_buffer(page_idx: usize, buffer: Arc, ctx: &EncryptionContext) -> Arc { - let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx).unwrap(); - Arc::new(Buffer::new(encrypted_data.to_vec())) +fn encrypt_buffer( + page_idx: usize, + buffer: Arc, + ctx: &EncryptionContext, +) -> Result> { + let encrypted_data = ctx.encrypt_page(buffer.as_slice(), page_idx)?; + Ok(Arc::new(Buffer::new(encrypted_data.to_vec()))) } -fn checksum_buffer(page_idx: usize, buffer: Arc, ctx: &ChecksumContext) -> Arc { - 
ctx.add_checksum_to_page(buffer.as_mut_slice(), page_idx) - .unwrap(); - buffer +fn checksum_buffer( + page_idx: usize, + buffer: Arc, + ctx: &ChecksumContext, +) -> Result> { + ctx.add_checksum_to_page(buffer.as_mut_slice(), page_idx)?; + Ok(buffer) } diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs index f0184dbc04..763c4eecba 100644 --- a/core/storage/encryption.rs +++ b/core/storage/encryption.rs @@ -94,11 +94,15 @@ impl EncryptionKey { match bytes.len() { 16 => { - let key: [u8; 16] = bytes.try_into().unwrap(); + let key: [u8; 16] = bytes.try_into().map_err(|_| { + LimboError::Corrupt("Failed to convert bytes to 16-byte key".into()) + })?; Ok(Self::Key128(key)) } 32 => { - let key: [u8; 32] = bytes.try_into().unwrap(); + let key: [u8; 32] = bytes.try_into().map_err(|_| { + LimboError::Corrupt("Failed to convert bytes to 32-byte key".into()) + })?; Ok(Self::Key256(key)) } _ => Err(LimboError::InvalidArgument(format!( diff --git a/core/storage/page_cache.rs b/core/storage/page_cache.rs index e8ac4e657a..98f2ef6346 100644 --- a/core/storage/page_cache.rs +++ b/core/storage/page_cache.rs @@ -135,15 +135,13 @@ impl PageCache { let mut cursor = self.queue.cursor_mut_from_ptr(self.clock_hand); cursor.move_next(); - if cursor.get().is_some() { - self.clock_hand = - cursor.as_cursor().get().unwrap() as *const _ as *mut PageCacheEntry; + if let Some(entry) = cursor.get() { + self.clock_hand = entry as *const _ as *mut PageCacheEntry; } else { // Reached end, wrap to front let front_cursor = self.queue.front_mut(); - if front_cursor.get().is_some() { - self.clock_hand = - front_cursor.as_cursor().get().unwrap() as *const _ as *mut PageCacheEntry; + if let Some(entry) = front_cursor.get() { + self.clock_hand = entry as *const _ as *mut PageCacheEntry; } else { self.clock_hand = std::ptr::null_mut(); } @@ -204,7 +202,12 @@ impl PageCache { if self.clock_hand.is_null() { // First entry - just push it self.queue.push_back(entry); - let entry_ptr = self.queue.back().get().unwrap() as *const _ as *mut PageCacheEntry; + let entry_ptr = self + .queue + .back() + .get() + .expect("queue should have entry after push_back") + as *const _ as *mut PageCacheEntry; self.map.insert(key, entry_ptr); self.clock_hand = entry_ptr; } else { diff --git a/core/storage/pager.rs b/core/storage/pager.rs index 465f0393e7..a9be266de2 100644 --- a/core/storage/pager.rs +++ b/core/storage/pager.rs @@ -214,7 +214,10 @@ impl Page { } pub fn get_contents(&self) -> &mut PageContent { - self.get().contents.as_mut().unwrap() + self.get() + .contents + .as_mut() + .expect("page contents should be initialized") } pub fn is_locked(&self) -> bool { @@ -750,7 +753,9 @@ impl Pager { ); let savepoints = savepoints.read(); - let cur_savepoint = savepoints.last().unwrap(); + let cur_savepoint = savepoints + .last() + .expect("savepoint must exist when writing to subjournal"); cur_savepoint.add_dirty_page(page_id as u32); cur_savepoint .write_offset @@ -760,7 +765,9 @@ impl Pager { let c = Completion::new_write(write_complete); let subjournal = self.subjournal.read(); - let subjournal = subjournal.as_ref().unwrap(); + let subjournal = subjournal + .as_ref() + .expect("subjournal must be open when writing page"); let c = subjournal.write_page(write_offset, page_size, buffer.clone(), c)?; assert!(c.succeeded(), "memory IO should complete immediately"); @@ -769,7 +776,12 @@ impl Pager { pub fn open_savepoint(&self, db_size: u32) -> Result<()> { self.open_subjournal()?; - let subjournal_offset = 
self.subjournal.read().as_ref().unwrap().size()?; + let subjournal_offset = self + .subjournal + .read() + .as_ref() + .expect("subjournal should be open after open_subjournal") + .size()?; // Currently as we only have anonymous savepoints opened at the start of a statement, // the subjournal offset should always be 0 as we should only have max 1 savepoint // opened at any given time. @@ -844,7 +856,13 @@ impl Pager { let page_id_buffer = Arc::new(self.buffer_pool.allocate(4)); let c = subjournal.read_page_number(current_offset, page_id_buffer.clone())?; assert!(c.succeeded(), "memory IO should complete immediately"); - let page_id = u32::from_be_bytes(page_id_buffer.as_slice()[0..4].try_into().unwrap()); + let page_id_slice = page_id_buffer.as_slice(); + let page_id = u32::from_be_bytes([ + page_id_slice[0], + page_id_slice[1], + page_id_slice[2], + page_id_slice[3], + ]); current_offset += 4; // Check if we've already rolled back this page or if the page is beyond the database size at the start of the savepoint @@ -892,7 +910,7 @@ impl Pager { .subjournal .read() .as_ref() - .unwrap() + .expect("subjournal must be open during rollback") .truncate(journal_start_offset)?; assert!( truncate_completion.succeeded(), @@ -1274,7 +1292,11 @@ impl Pager { tracing::debug!("Pager::allocate_overflow_page(id={})", page.get().id); // setup overflow page - let contents = page.get().contents.as_mut().unwrap(); + let contents = page + .get() + .contents + .as_mut() + .expect("page contents should be initialized"); let buf = contents.as_ptr(); buf.fill(0); @@ -1982,7 +2004,7 @@ impl Pager { .now() .to_system_time() .duration_since(self.commit_info.read().time.to_system_time()) - .unwrap() + .expect("commit time should be in the past") .as_millis() ); let (should_finish, result, completion) = { @@ -2011,7 +2033,7 @@ impl Pager { #[instrument(skip_all, level = Level::DEBUG)] pub fn wal_changed_pages_after(&self, frame_watermark: u64) -> Result> { - let wal = self.wal.as_ref().unwrap().borrow(); + let wal = self.wal.as_ref().expect("WAL should be enabled").borrow(); wal.changed_pages_after(frame_watermark) } @@ -2356,7 +2378,11 @@ impl Pager { let trunk_page_id = header.freelist_trunk_page.get(); - let contents = page.get().contents.as_mut().unwrap(); + let contents = page + .get() + .contents + .as_mut() + .expect("page contents should be initialized"); // Point to previous trunk contents.write_u32_no_offset(TRUNK_PAGE_NEXT_PAGE_OFFSET, trunk_page_id); // Zero leaf count @@ -2658,7 +2684,7 @@ impl Pager { let mut cache = self.page_cache.write(); cache .insert(page_key, richard_hipp_special_page.clone()) - .unwrap(); + .expect("cache insertion should succeed"); } // HIPP special page is assumed to be zeroed and should never be read or written to by the BTREE new_db_size += 1; diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs index 887ed23e04..07f817f49a 100644 --- a/core/storage/sqlite3_ondisk.rs +++ b/core/storage/sqlite3_ondisk.rs @@ -468,7 +468,9 @@ impl PageContent { } pub fn page_type(&self) -> PageType { - self.read_u8(BTREE_PAGE_TYPE).try_into().unwrap() + self.read_u8(BTREE_PAGE_TYPE) + .try_into() + .expect("Invalid page type") } pub fn maybe_page_type(&self) -> Option { @@ -804,6 +806,7 @@ impl PageContent { min_local, page_type, ) + .expect("Failed to get cell raw region") } /// Get region (start and length) of a cell's payload @@ -815,13 +818,13 @@ impl PageContent { max_local: usize, min_local: usize, page_type: PageType, - ) -> (usize, usize) { + ) -> Result<(usize, 
usize)> { let buf = self.as_ptr(); assert!(idx < cell_count, "cell_get: idx out of bounds"); let start = self.cell_get_raw_start_offset(idx); let len = match page_type { PageType::IndexInterior => { - let (len_payload, n_payload) = read_varint(&buf[start + 4..]).unwrap(); + let (len_payload, n_payload) = read_varint(&buf[start + 4..])?; let (overflows, to_read) = payload_overflows(len_payload as usize, max_local, min_local, usable_size); if overflows { @@ -831,11 +834,11 @@ impl PageContent { } } PageType::TableInterior => { - let (_, n_rowid) = read_varint(&buf[start + 4..]).unwrap(); + let (_, n_rowid) = read_varint(&buf[start + 4..])?; 4 + n_rowid } PageType::IndexLeaf => { - let (len_payload, n_payload) = read_varint(&buf[start..]).unwrap(); + let (len_payload, n_payload) = read_varint(&buf[start..])?; let (overflows, to_read) = payload_overflows(len_payload as usize, max_local, min_local, usable_size); if overflows { @@ -849,8 +852,8 @@ impl PageContent { } } PageType::TableLeaf => { - let (len_payload, n_payload) = read_varint(&buf[start..]).unwrap(); - let (_, n_rowid) = read_varint(&buf[start + n_payload..]).unwrap(); + let (len_payload, n_payload) = read_varint(&buf[start..])?; + let (_, n_rowid) = read_varint(&buf[start + n_payload..])?; let (overflows, to_read) = payload_overflows(len_payload as usize, max_local, min_local, usable_size); if overflows { @@ -864,7 +867,7 @@ impl PageContent { } } }; - (start, len) + Ok((start, len)) } pub fn is_leaf(&self) -> bool { @@ -1269,7 +1272,10 @@ impl SmallVec { if self.extra_data.is_none() { self.extra_data = Some(Vec::new()); } - self.extra_data.as_mut().unwrap().push(value); + self.extra_data + .as_mut() + .expect("extra_data was just initialized") + .push(value); self.len += 1; } } @@ -1278,7 +1284,10 @@ impl SmallVec { assert!(self.extra_data.is_some()); assert!(index >= self.data.len()); let extra_data_index = index - self.data.len(); - let extra_data = self.extra_data.as_ref().unwrap(); + let extra_data = self + .extra_data + .as_ref() + .expect("extra_data existence was just asserted"); assert!(extra_data_index < extra_data.len()); extra_data[extra_data_index] } @@ -1806,14 +1815,14 @@ impl StreamingWalReader { let (page_sz, c1, c2, use_native, ok) = { let mut h = self.header.lock(); let s = buf.as_slice(); - h.magic = u32::from_be_bytes(s[0..4].try_into().unwrap()); - h.file_format = u32::from_be_bytes(s[4..8].try_into().unwrap()); - h.page_size = u32::from_be_bytes(s[8..12].try_into().unwrap()); - h.checkpoint_seq = u32::from_be_bytes(s[12..16].try_into().unwrap()); - h.salt_1 = u32::from_be_bytes(s[16..20].try_into().unwrap()); - h.salt_2 = u32::from_be_bytes(s[20..24].try_into().unwrap()); - h.checksum_1 = u32::from_be_bytes(s[24..28].try_into().unwrap()); - h.checksum_2 = u32::from_be_bytes(s[28..32].try_into().unwrap()); + h.magic = u32::from_be_bytes([s[0], s[1], s[2], s[3]]); + h.file_format = u32::from_be_bytes([s[4], s[5], s[6], s[7]]); + h.page_size = u32::from_be_bytes([s[8], s[9], s[10], s[11]]); + h.checkpoint_seq = u32::from_be_bytes([s[12], s[13], s[14], s[15]]); + h.salt_1 = u32::from_be_bytes([s[16], s[17], s[18], s[19]]); + h.salt_2 = u32::from_be_bytes([s[20], s[21], s[22], s[23]]); + h.checksum_1 = u32::from_be_bytes([s[24], s[25], s[26], s[27]]); + h.checksum_2 = u32::from_be_bytes([s[28], s[29], s[30], s[31]]); tracing::debug!("WAL header: {:?}", *h); let use_native = cfg!(target_endian = "big") == ((h.magic & 1) != 0); @@ -1874,12 +1883,12 @@ impl StreamingWalReader { let fh = &buf[pos..pos + 
WAL_FRAME_HEADER_SIZE]; let page = &buf[pos + WAL_FRAME_HEADER_SIZE..pos + frame_size]; - let page_number = u32::from_be_bytes(fh[0..4].try_into().unwrap()); - let db_size = u32::from_be_bytes(fh[4..8].try_into().unwrap()); - let s1 = u32::from_be_bytes(fh[8..12].try_into().unwrap()); - let s2 = u32::from_be_bytes(fh[12..16].try_into().unwrap()); - let c1 = u32::from_be_bytes(fh[16..20].try_into().unwrap()); - let c2 = u32::from_be_bytes(fh[20..24].try_into().unwrap()); + let page_number = u32::from_be_bytes([fh[0], fh[1], fh[2], fh[3]]); + let db_size = u32::from_be_bytes([fh[4], fh[5], fh[6], fh[7]]); + let s1 = u32::from_be_bytes([fh[8], fh[9], fh[10], fh[11]]); + let s2 = u32::from_be_bytes([fh[12], fh[13], fh[14], fh[15]]); + let c1 = u32::from_be_bytes([fh[16], fh[17], fh[18], fh[19]]); + let c2 = u32::from_be_bytes([fh[20], fh[21], fh[22], fh[23]]); if page_number == 0 { break; @@ -2062,12 +2071,12 @@ pub fn begin_read_wal_frame( } pub fn parse_wal_frame_header(frame: &[u8]) -> (WalFrameHeader, &[u8]) { - let page_number = u32::from_be_bytes(frame[0..4].try_into().unwrap()); - let db_size = u32::from_be_bytes(frame[4..8].try_into().unwrap()); - let salt_1 = u32::from_be_bytes(frame[8..12].try_into().unwrap()); - let salt_2 = u32::from_be_bytes(frame[12..16].try_into().unwrap()); - let checksum_1 = u32::from_be_bytes(frame[16..20].try_into().unwrap()); - let checksum_2 = u32::from_be_bytes(frame[20..24].try_into().unwrap()); + let page_number = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]); + let db_size = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]); + let salt_1 = u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]); + let salt_2 = u32::from_be_bytes([frame[12], frame[13], frame[14], frame[15]]); + let checksum_1 = u32::from_be_bytes([frame[16], frame[17], frame[18], frame[19]]); + let checksum_2 = u32::from_be_bytes([frame[20], frame[21], frame[22], frame[23]]); let header = WalFrameHeader { page_number, db_size, @@ -2204,16 +2213,17 @@ pub fn checksum_wal( let mut i = 0; if native_endian { while i < buf.len() { - let v0 = u32::from_ne_bytes(buf[i..i + 4].try_into().unwrap()); - let v1 = u32::from_ne_bytes(buf[i + 4..i + 8].try_into().unwrap()); + let v0 = u32::from_ne_bytes([buf[i], buf[i + 1], buf[i + 2], buf[i + 3]]); + let v1 = u32::from_ne_bytes([buf[i + 4], buf[i + 5], buf[i + 6], buf[i + 7]]); s0 = s0.wrapping_add(v0.wrapping_add(s1)); s1 = s1.wrapping_add(v1.wrapping_add(s0)); i += 8; } } else { while i < buf.len() { - let v0 = u32::from_ne_bytes(buf[i..i + 4].try_into().unwrap()).swap_bytes(); - let v1 = u32::from_ne_bytes(buf[i + 4..i + 8].try_into().unwrap()).swap_bytes(); + let v0 = u32::from_ne_bytes([buf[i], buf[i + 1], buf[i + 2], buf[i + 3]]).swap_bytes(); + let v1 = + u32::from_ne_bytes([buf[i + 4], buf[i + 5], buf[i + 6], buf[i + 7]]).swap_bytes(); s0 = s0.wrapping_add(v0.wrapping_add(s1)); s1 = s1.wrapping_add(v1.wrapping_add(s0)); i += 8; diff --git a/core/storage/wal.rs b/core/storage/wal.rs index 018ad16aaa..b039c1ecd3 100644 --- a/core/storage/wal.rs +++ b/core/storage/wal.rs @@ -1031,7 +1031,7 @@ impl Wal for WalFile { self.with_shared(|shared| { let nbackfills = shared.nbackfills.load(Ordering::Acquire); turso_assert!( - frame_watermark.is_none() || frame_watermark.unwrap() >= nbackfills, + frame_watermark.map_or(true, |fw| fw >= nbackfills), "frame_watermark must be >= than current WAL backfill amount: frame_watermark={:?}, nBackfill={}", frame_watermark, nbackfills ); }); @@ -1126,7 +1126,11 @@ impl Wal for 
WalFile { // This means that readers trying to acquire the lock will block even if the lock is unlocked // when there are writers waiting to acquire the lock. // Because of this, attempts to recursively acquire a read lock within a single thread may result in a deadlock." - shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); begin_read_wal_frame( file.as_ref(), @@ -1194,7 +1198,11 @@ impl Wal for WalFile { }); let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); let c = begin_read_wal_frame_raw(&self.buffer_pool, file.as_ref(), offset, complete)?; Ok(c) @@ -1260,7 +1268,11 @@ impl Wal for WalFile { }); let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); let c = begin_read_wal_frame( file.as_ref(), @@ -1285,7 +1297,11 @@ impl Wal for WalFile { let (header, file) = self.with_shared(|shared| { let header = shared.wal_header.clone(); assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - let file = shared.file.as_ref().unwrap().clone(); + let file = shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone(); (header, file) }); let header = header.lock(); @@ -1340,7 +1356,11 @@ impl Wal for WalFile { }); let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); self.syncing.store(true, Ordering::SeqCst); let c = file.sync(completion)?; @@ -1465,7 +1485,11 @@ impl Wal for WalFile { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); ( *shared.wal_header.lock(), - shared.file.as_ref().unwrap().clone(), + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone(), ) }); let c = sqlite3_ondisk::begin_write_wal_header(file.as_ref(), &header)?; @@ -1475,7 +1499,11 @@ impl Wal for WalFile { fn prepare_wal_finish(&mut self) -> Result { let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); let shared = self.shared.clone(); let c = file.sync(Completion::new_sync(move |_| { @@ -1598,7 +1626,11 @@ impl Wal for WalFile { let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); let c = file.pwritev(start_off, iovecs, c)?; Ok(c) @@ -2035,7 +2067,9 @@ impl WalFile { else { panic!("unexpected state"); }; - checkpoint_result.take().unwrap() + checkpoint_result + .take() + .expect("checkpoint result must be present in Truncate state") }; // increment wal epoch to ensure no stale pages are used for backfilling self.with_shared(|shared| shared.epoch.fetch_add(1, Ordering::Release)); @@ -2178,7 +2212,11 @@ impl WalFile { let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); shared.initialized.store(false, Ordering::Release);
- shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); let CheckpointState::Truncate { @@ -2288,7 +2326,11 @@ impl WalFile { // schedule read of the page payload let file = self.with_shared(|shared| { assert!(shared.enabled.load(Ordering::SeqCst), "WAL must be enabled"); - shared.file.as_ref().unwrap().clone() + shared + .file + .as_ref() + .expect("WAL file must be initialized") + .clone() }); let c = begin_read_wal_frame( file.as_ref(), diff --git a/core/time/internal.rs b/core/time/internal.rs index 15548ac69b..c9f4fdb58e 100644 --- a/core/time/internal.rs +++ b/core/time/internal.rs @@ -200,9 +200,9 @@ impl Time { offset: FixedOffset, ) -> Result { let mut dt: NaiveDateTime = NaiveDate::from_ymd_opt(1, 1, 1) - .unwrap() + .ok_or(TimeError::CreationError)? .and_hms_opt(0, 0, 0) - .unwrap(); + .ok_or(TimeError::CreationError)?; match year.cmp(&0) { std::cmp::Ordering::Greater => { @@ -338,7 +338,7 @@ impl Time { let mut minutes: i64 = 0; let mut seconds: i64 = 0; let mut nano_secs: i64 = 0; - let offset = FixedOffset::east_opt(0).unwrap(); // UTC + let offset = FixedOffset::east_opt(0).ok_or(TimeError::InvalidOffset)?; // UTC match field { Millennium => { diff --git a/core/time/mod.rs b/core/time/mod.rs index 05ddddc639..dd610a4397 100644 --- a/core/time/mod.rs +++ b/core/time/mod.rs @@ -191,6 +191,8 @@ fn time_date_internal(args: &[Value]) -> Value { seconds -= offset_sec as i64; } + let offset = ok_tri!(FixedOffset::east_opt(0), "failed to create UTC offset"); + let t = Time::time_date( year as i32, month as i32, @@ -199,7 +201,7 @@ fn time_date_internal(args: &[Value]) -> Value { minutes, seconds, nano_secs, - FixedOffset::east_opt(0).unwrap(), + offset, ); let t = tri!(t); @@ -632,7 +634,10 @@ fn dur_ns(args: &[Value]) -> Value { return Value::error(ResultCode::InvalidArgs); } - Value::from_integer(chrono::Duration::nanoseconds(1).num_nanoseconds().unwrap()) + Value::from_integer(ok_tri!( + chrono::Duration::nanoseconds(1).num_nanoseconds(), + "failed to convert nanoseconds duration" + )) } /// 1 microsecond @@ -642,7 +647,10 @@ fn dur_us(args: &[Value]) -> Value { return Value::error(ResultCode::InvalidArgs); } - Value::from_integer(chrono::Duration::microseconds(1).num_nanoseconds().unwrap()) + Value::from_integer(ok_tri!( + chrono::Duration::microseconds(1).num_nanoseconds(), + "failed to convert microseconds duration" + )) } /// 1 millisecond @@ -652,7 +660,10 @@ fn dur_ms(args: &[Value]) -> Value { return Value::error(ResultCode::InvalidArgs); } - Value::from_integer(chrono::Duration::milliseconds(1).num_nanoseconds().unwrap()) + Value::from_integer(ok_tri!( + chrono::Duration::milliseconds(1).num_nanoseconds(), + "failed to convert milliseconds duration" + )) } /// 1 second @@ -662,7 +673,10 @@ fn dur_s(args: &[Value]) -> Value { return Value::error(ResultCode::InvalidArgs); } - Value::from_integer(chrono::Duration::seconds(1).num_nanoseconds().unwrap()) + Value::from_integer(ok_tri!( + chrono::Duration::seconds(1).num_nanoseconds(), + "failed to convert seconds duration" + )) } /// 1 minute @@ -672,7 +686,10 @@ fn dur_m(args: &[Value]) -> Value { return Value::error(ResultCode::InvalidArgs); } - Value::from_integer(chrono::Duration::minutes(1).num_nanoseconds().unwrap()) + Value::from_integer(ok_tri!( + chrono::Duration::minutes(1).num_nanoseconds(), + "failed to convert minutes duration" + )) } /// 1 hour @@ -682,7 +699,10 @@ fn dur_h(args: &[Value]) -> Value { return Value::error(ResultCode::InvalidArgs); 
} - Value::from_integer(chrono::Duration::hours(1).num_nanoseconds().unwrap()) + Value::from_integer(ok_tri!( + chrono::Duration::hours(1).num_nanoseconds(), + "failed to convert hours duration" + )) } // Time Arithmetic @@ -992,19 +1012,13 @@ fn time_parse(args: &[Value]) -> Value { } if let Ok(mut dt) = chrono::NaiveDateTime::parse_from_str(dt_str, "%Y-%m-%d %H:%M:%S") { - // Unwrap is safe here - dt = dt.with_nanosecond(0).unwrap(); + dt = ok_tri!(dt.with_nanosecond(0), "failed to set nanosecond to 0"); return Time::from_datetime(dt.and_utc()).into_blob(); } if let Ok(date) = chrono::NaiveDate::parse_from_str(dt_str, "%Y-%m-%d") { - // Unwrap is safe here - - let dt = date - .and_hms_opt(0, 0, 0) - .unwrap() - .with_nanosecond(0) - .unwrap(); + let dt = ok_tri!(date.and_hms_opt(0, 0, 0), "failed to create time 00:00:00"); + let dt = ok_tri!(dt.with_nanosecond(0), "failed to set nanosecond to 0"); return Time::from_datetime(dt.and_utc()).into_blob(); } @@ -1012,9 +1026,12 @@ fn time_parse(args: &[Value]) -> Value { chrono::NaiveTime::parse_from_str(dt_str, "%H:%M:%S"), "error parsing datetime string" ); - let dt = NaiveDateTime::new(NaiveDate::from_ymd_opt(1, 1, 1).unwrap(), time) - .with_nanosecond(0) - .unwrap(); + let date = ok_tri!( + NaiveDate::from_ymd_opt(1, 1, 1), + "failed to create baseline date" + ); + let dt = NaiveDateTime::new(date, time); + let dt = ok_tri!(dt.with_nanosecond(0), "failed to set nanosecond to 0"); Time::from_datetime(dt.and_utc()).into_blob() } diff --git a/core/translate/display.rs b/core/translate/display.rs index e2c5e804ca..cac87cb687 100644 --- a/core/translate/display.rs +++ b/core/translate/display.rs @@ -143,13 +143,14 @@ impl Display for SelectPlan { } }, Operation::IndexMethodQuery(query) => { - let index_method = query.index.index_method.as_ref().unwrap(); - writeln!( - f, - "{}QUERY INDEX METHOD {}", - indent, - index_method.definition().method_name - )?; + if let Some(index_method) = query.index.index_method.as_ref() { + writeln!( + f, + "{}QUERY INDEX METHOD {}", + indent, + index_method.definition().method_name + )?; + } } } } @@ -217,13 +218,14 @@ impl Display for DeletePlan { } }, Operation::IndexMethodQuery(query) => { - let module = query.index.index_method.as_ref().unwrap(); - writeln!( - f, - "{}QUERY MODULE {}", - indent, - module.definition().method_name - )?; + if let Some(module) = query.index.index_method.as_ref() { + writeln!( + f, + "{}QUERY MODULE {}", + indent, + module.definition().method_name + )?; + } } } } @@ -304,13 +306,14 @@ impl fmt::Display for UpdatePlan { } }, Operation::IndexMethodQuery(query) => { - let module = query.index.index_method.as_ref().unwrap(); - writeln!( - f, - "{}QUERY MODULE {}", - indent, - module.definition().method_name - )?; + if let Some(module) = query.index.index_method.as_ref() { + writeln!( + f, + "{}QUERY MODULE {}", + indent, + module.definition().method_name + )?; + } } } } @@ -517,7 +520,10 @@ impl ToTokens for SelectPlan { s.append(TokenType::TK_JOIN, None)?; } - let table_ref = self.joined_tables().get(order.original_idx).unwrap(); + let table_ref = self + .joined_tables() + .get(order.original_idx) + .expect("table reference not found for join order"); table_ref.to_tokens(s, context)?; } @@ -701,14 +707,11 @@ impl ToTokens for UpdatePlan { s.comma( self.set_clauses.iter().map(|(col_idx, set_expr)| { - let col_name = table + let col = table .table .get_column_at(*col_idx) - .as_ref() - .unwrap() - .name - .as_ref() - .unwrap(); + .unwrap_or_else(|| panic!("column at index {} not 
found", col_idx)); + let col_name = col.name.as_ref().expect("column must have a name"); ast::Set { col_names: vec![ast::Name::exact(col_name.clone())], diff --git a/core/translate/emitter.rs b/core/translate/emitter.rs index e157589aa7..3d442df79c 100644 --- a/core/translate/emitter.rs +++ b/core/translate/emitter.rs @@ -399,7 +399,9 @@ pub fn emit_query<'a>( if plan.is_simple_count() { emit_simple_count(program, t_ctx, plan)?; - return Ok(t_ctx.reg_result_cols_start.unwrap()); + return t_ctx.reg_result_cols_start.ok_or_else(|| { + crate::LimboError::InternalError("reg_result_cols_start was not set".to_string()) + }); } // Set up main query execution loop @@ -437,7 +439,9 @@ pub fn emit_query<'a>( let row_source = &t_ctx .meta_group_by .as_ref() - .expect("group by metadata not found") + .ok_or_else(|| { + crate::LimboError::InternalError("group by metadata not found".to_string()) + })? .row_source; if matches!(row_source, GroupByRowSource::Sorter { .. }) { group_by_agg_phase(program, t_ctx, plan)?; @@ -457,7 +461,9 @@ pub fn emit_query<'a>( emit_order_by(program, t_ctx, plan)?; } - Ok(t_ctx.reg_result_cols_start.unwrap()) + t_ctx.reg_result_cols_start.ok_or_else(|| { + crate::LimboError::InternalError("reg_result_cols_start was not set".to_string()) + }) } #[instrument(skip_all, level = Level::DEBUG)] @@ -514,9 +520,11 @@ fn emit_program_for_delete( // If there's a rowset_plan, materialize rowids into a RowSet first and then iterate the RowSet // to delete the rows. if let Some(rowset_plan) = plan.rowset_plan.take() { - let rowset_reg = plan - .rowset_reg - .expect("rowset_reg must be Some if rowset_plan is Some"); + let rowset_reg = plan.rowset_reg.ok_or_else(|| { + crate::LimboError::InternalError( + "rowset_reg must be Some if rowset_plan is Some".to_string(), + ) + })?; // Initialize the RowSet register with NULL (RowSet will be created on first RowSetAdd) program.emit_insn(Insn::Null { @@ -530,7 +538,13 @@ fn emit_program_for_delete( program.decr_nesting(); // Close the read cursor(s) opened by the rowset plan before opening for writing - let table_ref = plan.table_references.joined_tables().first().unwrap(); + let table_ref = plan + .table_references + .joined_tables() + .first() + .ok_or_else(|| { + crate::LimboError::InternalError("plan must have at least one table".to_string()) + })?; let table_cursor_id_read = program.resolve_cursor_id(&CursorKey::table(table_ref.internal_id)); program.emit_insn(Insn::Close { @@ -655,7 +669,9 @@ pub fn emit_fk_child_decrement_on_delete( // Fast path: if any FK column is NULL, it can't be a violation let null_skip = program.allocate_label(); for cname in &fk_ref.child_cols { - let (pos, col) = child_tbl.get_column(cname).unwrap(); + let (pos, col) = child_tbl.get_column(cname).ok_or_else(|| { + crate::LimboError::InternalError(format!("foreign key column {} not found", cname)) + })?; let src = if col.is_rowid_alias() { child_rowid_reg } else { @@ -679,10 +695,20 @@ pub fn emit_fk_child_decrement_on_delete( let parent_tbl = resolver .schema .get_btree_table(&fk_ref.fk.parent_table) - .expect("parent btree"); + .ok_or_else(|| { + crate::LimboError::InternalError(format!( + "parent btree table {} not found", + fk_ref.fk.parent_table + )) + })?; let pcur = open_read_table(program, &parent_tbl); - let (pos, col) = child_tbl.get_column(&fk_ref.child_cols[0]).unwrap(); + let (pos, col) = child_tbl.get_column(&fk_ref.child_cols[0]).ok_or_else(|| { + crate::LimboError::InternalError(format!( + "foreign key column {} not found", + &fk_ref.child_cols[0] + )) + })?;
let val = if col.is_rowid_alias() { child_rowid_reg } else { @@ -727,15 +753,29 @@ pub fn emit_fk_child_decrement_on_delete( let parent_tbl = resolver .schema .get_btree_table(&fk_ref.fk.parent_table) - .expect("parent btree"); - let idx = fk_ref.parent_unique_index.as_ref().expect("unique index"); + .ok_or_else(|| { + crate::LimboError::InternalError(format!( + "parent btree table {} not found", + fk_ref.fk.parent_table + )) + })?; + let idx = fk_ref.parent_unique_index.as_ref().ok_or_else(|| { + crate::LimboError::InternalError( + "unique index not found for foreign key".to_string(), + ) + })?; let icur = open_read_index(program, idx); // Build probe from current child row let n = fk_ref.child_cols.len(); let probe = program.alloc_registers(n); for (i, cname) in fk_ref.child_cols.iter().enumerate() { - let (pos, col) = child_tbl.get_column(cname).unwrap(); + let (pos, col) = child_tbl.get_column(cname).ok_or_else(|| { + crate::LimboError::InternalError(format!( + "foreign key column {} not found", + cname + )) + })?; let src = if col.is_rowid_alias() { child_rowid_reg } else { @@ -756,7 +796,9 @@ } program.emit_insn(Insn::Affinity { start_reg: probe, - count: std::num::NonZeroUsize::new(n).unwrap(), + count: std::num::NonZeroUsize::new(n).ok_or_else(|| { + crate::LimboError::InternalError("foreign key column count is zero".to_string()) + })?, affinities: build_index_affinity_string(idx, &parent_tbl), }); @@ -785,7 +827,10 @@ fn emit_delete_insns<'a>( result_columns: &'a [super::plan::ResultSetColumn], ) -> Result<()> { // we can either use this obviously safe raw pointer or we can clone it - let table_reference: *const JoinedTable = table_references.joined_tables().first().unwrap(); + let table_reference: *const JoinedTable = + table_references.joined_tables().first().ok_or_else(|| { + crate::LimboError::InternalError("plan must have at least one table".to_string()) + })?; if unsafe { &*table_reference } .virtual_table() .is_some_and(|t| t.readonly()) { @@ -883,7 +928,9 @@ fn emit_delete_insns<'a>( if let Some(limit_ctx) = t_ctx.limit_ctx { program.emit_insn(Insn::DecrJumpZero { reg: limit_ctx.reg_limit, - target_pc: t_ctx.label_main_loop_end.unwrap(), + target_pc: t_ctx.label_main_loop_end.ok_or_else(|| { + crate::LimboError::InternalError("label_main_loop_end was not set".to_string()) + })?, }) } @@ -975,7 +1022,11 @@ fn emit_delete_row_common( let skip_delete_label = if index.where_clause.is_some() { let where_copy = index .bind_where_expr(Some(table_references), connection) - .expect("where clause to exist"); + .ok_or_else(|| { + crate::LimboError::InternalError( + "where clause not found for partial index".to_string(), + ) + })?; let skip_label = program.allocate_label(); let reg = program.alloc_register(); translate_expr_no_constant_opt( @@ -1051,8 +1102,12 @@ fn emit_delete_row_common( // Emit RETURNING results if specified (must be before DELETE) if !result_columns.is_empty() { - let columns_start_reg = columns_start_reg - .expect("columns_start_reg must be provided when there are triggers or RETURNING"); + let columns_start_reg = columns_start_reg.ok_or_else(|| { + crate::LimboError::InternalError( + "columns_start_reg must be provided when there are triggers or RETURNING" + .to_string(), + ) + })?; // Emit RETURNING results using the values we just read emit_returning_results( program, @@ -1095,7 +1150,10 @@ fn emit_delete_insns_when_triggers_present( target_pc: skip_not_found_label, }); - let table_reference: *const JoinedTable = 
table_references.joined_tables().first().unwrap(); + let table_reference: *const JoinedTable = + table_references.joined_tables().first().ok_or_else(|| { + crate::LimboError::InternalError("plan must have at least one table".to_string()) + })?; if unsafe { &*table_reference } .virtual_table() .is_some_and(|t| t.readonly()) @@ -1139,8 +1197,12 @@ fn emit_delete_insns_when_triggers_present( ); let has_relevant_triggers = relevant_triggers.clone().count() > 0; if has_relevant_triggers { - let columns_start_reg = columns_start_reg - .expect("columns_start_reg must be provided when there are triggers or RETURNING"); + let columns_start_reg = columns_start_reg.ok_or_else(|| { + crate::LimboError::InternalError( + "columns_start_reg must be provided when there are triggers or RETURNING" + .to_string(), + ) + })?; let old_registers = (0..cols_len) .map(|i| columns_start_reg + i) .chain(std::iter::once(rowid_reg)) .collect(); @@ -1195,8 +1257,12 @@ ); let has_relevant_triggers = relevant_triggers.clone().count() > 0; if has_relevant_triggers { - let columns_start_reg = columns_start_reg - .expect("columns_start_reg must be provided when there are triggers or RETURNING"); + let columns_start_reg = columns_start_reg.ok_or_else(|| { + crate::LimboError::InternalError( + "columns_start_reg must be provided when there are triggers or RETURNING" + .to_string(), + ) + })?; let old_registers = (0..cols_len) .map(|i| columns_start_reg + i) .chain(std::iter::once(rowid_reg)) .collect(); @@ -1265,10 +1331,14 @@ fn emit_program_for_update( .table_references .joined_tables() .first() - .unwrap() + .ok_or_else(|| { + crate::LimboError::InternalError("no joined tables in ephemeral plan".to_string()) + })? .clone(); program.emit_insn(Insn::OpenEphemeral { - cursor_id: temp_cursor_id.unwrap(), + cursor_id: temp_cursor_id.ok_or_else(|| { + crate::LimboError::InternalError("temp_cursor_id was not set".to_string()) + })?, is_table: true, }); program.incr_nesting(); @@ -1280,16 +1350,21 @@ fn emit_program_for_update( plan.table_references .joined_tables() .first() - .unwrap() + .ok_or_else(|| { + crate::LimboError::InternalError("no joined tables in plan".to_string()) + })? 
.clone(), ) }; let mode = OperationMode::UPDATE(if has_ephemeral_table { UpdateRowSource::PrebuiltEphemeralTable { - ephemeral_table_cursor_id: temp_cursor_id.expect( - "ephemeral table cursor id is always allocated if has_ephemeral_table is true", - ), + ephemeral_table_cursor_id: temp_cursor_id.ok_or_else(|| { + crate::LimboError::InternalError( + "ephemeral table cursor id must be allocated when has_ephemeral_table is true" + .to_string(), + ) + })?, target_table: target_table.clone(), } } else { @@ -1324,14 +1399,16 @@ fn emit_program_for_update( // Prepare index cursors let mut index_cursors = Vec::with_capacity(plan.indexes_to_update.len()); for index in &plan.indexes_to_update { - let index_cursor = if let Some(cursor) = program.resolve_cursor_id_safe(&CursorKey::index( - plan.table_references - .joined_tables() - .first() - .unwrap() - .internal_id, - index.clone(), - )) { + let first_table = plan + .table_references + .joined_tables() + .first() + .ok_or_else(|| { + crate::LimboError::InternalError("no joined tables in plan".to_string()) + })?; + let index_cursor = if let Some(cursor) = program + .resolve_cursor_id_safe(&CursorKey::index(first_table.internal_id, index.clone())) + { cursor } else { let cursor = program.alloc_cursor_index(None, index)?; @@ -1362,7 +1439,9 @@ fn emit_program_for_update( program.resolve_cursor_id(&CursorKey::table(target_table.internal_id)); let iteration_cursor_id = if has_ephemeral_table { - temp_cursor_id.unwrap() + temp_cursor_id.ok_or_else(|| { + crate::LimboError::InternalError("temp_cursor_id was not set".to_string()) + })? } else { target_table_cursor_id }; @@ -1428,7 +1507,9 @@ fn emit_update_column_values<'a>( if has_direct_rowid_update { if let Some((_, expr)) = set_clauses.iter().find(|(i, _)| *i == ROWID_SENTINEL) { if !skip_set_clauses { - let rowid_set_clause_reg = rowid_set_clause_reg.unwrap(); + let rowid_set_clause_reg = rowid_set_clause_reg.ok_or_else(|| { + crate::LimboError::InternalError("rowid_set_clause_reg was not set".to_string()) + })?; translate_expr( program, Some(table_references), @@ -1454,7 +1535,11 @@ fn emit_update_column_values<'a>( && (table_column.primary_key() || table_column.is_rowid_alias()) && !is_virtual { - let rowid_set_clause_reg = rowid_set_clause_reg.unwrap(); + let rowid_set_clause_reg = rowid_set_clause_reg.ok_or_else(|| { + crate::LimboError::InternalError( + "rowid_set_clause_reg was not set".to_string(), + ) + })?; translate_expr( program, Some(table_references), @@ -1484,10 +1569,11 @@ fn emit_update_column_values<'a>( description: format!( "{}.{}", table_name, - table_column - .name - .as_ref() - .expect("Column name must be present") + table_column.name.as_ref().ok_or_else(|| { + crate::LimboError::InternalError( + "Column name must be present".to_string(), + ) + })? ), }); } @@ -1589,8 +1675,13 @@ fn emit_update_insns<'a>( target_table: Arc, ) -> crate::Result<()> { let internal_id = target_table.internal_id; - let loop_labels = t_ctx.labels_main_loop.first().unwrap(); - let source_table = table_references.joined_tables().first().unwrap(); + let loop_labels = t_ctx.labels_main_loop.first().ok_or_else(|| { + crate::LimboError::InternalError("main loop labels were not set".to_string()) + })?; + let source_table = table_references + .joined_tables() + .first() + .ok_or_else(|| crate::LimboError::InternalError("no joined tables in plan".to_string()))?; let (index, is_virtual) = match &source_table.op { Operation::Scan(Scan::BTreeTable { index, .. 
}) => ( index.as_ref().map(|index| { @@ -1666,13 +1757,19 @@ fn emit_update_insns<'a>( program.emit_insn(Insn::NotExists { cursor: target_table_cursor_id, rowid_reg: beg, - target_pc: check_rowid_not_exists_label.unwrap(), + target_pc: check_rowid_not_exists_label.ok_or_else(|| { + crate::LimboError::InternalError( + "check_rowid_not_exists_label was not set".to_string(), + ) + })?, }); } else { // if no rowid, we're done program.emit_insn(Insn::IsNull { reg: beg, - target_pc: t_ctx.label_main_loop_end.unwrap(), + target_pc: t_ctx.label_main_loop_end.ok_or_else(|| { + crate::LimboError::InternalError("label_main_loop_end was not set".to_string()) + })?, }); } @@ -1730,94 +1827,96 @@ fn emit_update_insns<'a>( )?; // Fire BEFORE UPDATE triggers and preserve old_registers for AFTER triggers - let preserved_old_registers: Option> = - if let Some(btree_table) = target_table.table.btree() { - let updated_column_indices: std::collections::HashSet = - set_clauses.iter().map(|(col_idx, _)| *col_idx).collect(); - let relevant_before_update_triggers = get_relevant_triggers_type_and_time( - t_ctx.resolver.schema, - TriggerEvent::Update, - TriggerTime::Before, - Some(updated_column_indices.clone()), - &btree_table, - ); - // Read OLD row values for trigger context - let old_registers: Vec = (0..col_len) - .map(|i| { - let reg = program.alloc_register(); - program.emit_column_or_rowid(target_table_cursor_id, i, reg); - reg - }) + let preserved_old_registers: Option> = if let Some(btree_table) = + target_table.table.btree() + { + let updated_column_indices: std::collections::HashSet = + set_clauses.iter().map(|(col_idx, _)| *col_idx).collect(); + let relevant_before_update_triggers = get_relevant_triggers_type_and_time( + t_ctx.resolver.schema, + TriggerEvent::Update, + TriggerTime::Before, + Some(updated_column_indices.clone()), + &btree_table, + ); + // Read OLD row values for trigger context + let old_registers: Vec = (0..col_len) + .map(|i| { + let reg = program.alloc_register(); + program.emit_column_or_rowid(target_table_cursor_id, i, reg); + reg + }) + .chain(std::iter::once(beg)) + .collect(); + let has_relevant_triggers = relevant_before_update_triggers.clone().count() > 0; + if !has_relevant_triggers { + Some(old_registers) + } else { + // NEW row values are already in 'start' registers + let new_registers = (0..col_len) + .map(|i| start + i) .chain(std::iter::once(beg)) .collect(); - let has_relevant_triggers = relevant_before_update_triggers.clone().count() > 0; - if !has_relevant_triggers { - Some(old_registers) - } else { - // NEW row values are already in 'start' registers - let new_registers = (0..col_len) - .map(|i| start + i) - .chain(std::iter::once(beg)) - .collect(); - let trigger_ctx = TriggerContext::new( - btree_table.clone(), - Some(new_registers), - Some(old_registers.clone()), // Clone for AFTER trigger - ); + let trigger_ctx = TriggerContext::new( + btree_table.clone(), + Some(new_registers), + Some(old_registers.clone()), // Clone for AFTER trigger + ); - for trigger in relevant_before_update_triggers { - fire_trigger( - program, - &mut t_ctx.resolver, - trigger, - &trigger_ctx, - connection, - )?; - } + for trigger in relevant_before_update_triggers { + fire_trigger( + program, + &mut t_ctx.resolver, + trigger, + &trigger_ctx, + connection, + )?; + } - // BEFORE UPDATE Triggers may have altered the btree so we need to seek again. - program.emit_insn(Insn::NotExists { + // BEFORE UPDATE Triggers may have altered the btree so we need to seek again. 
+ program.emit_insn(Insn::NotExists { cursor: target_table_cursor_id, rowid_reg: beg, - target_pc: check_rowid_not_exists_label.expect( - "check_rowid_not_exists_label must be set if there are BEFORE UPDATE triggers", - ), + target_pc: check_rowid_not_exists_label + .ok_or_else(|| crate::LimboError::InternalError( + "check_rowid_not_exists_label must be set if there are BEFORE UPDATE triggers".to_string(), + ))?, }); - let has_relevant_after_triggers = get_relevant_triggers_type_and_time( - t_ctx.resolver.schema, - TriggerEvent::Update, - TriggerTime::After, - Some(updated_column_indices), - &btree_table, - ) - .clone() - .count() - > 0; - if has_relevant_after_triggers { - // Preserve pseudo-row 'OLD' for AFTER triggers by copying to new registers - // (since registers might be overwritten during trigger execution) - let preserved: Vec = old_registers - .iter() - .map(|old_reg| { - let preserved_reg = program.alloc_register(); - program.emit_insn(Insn::Copy { - src_reg: *old_reg, - dst_reg: preserved_reg, - extra_amount: 0, - }); - preserved_reg - }) - .collect(); - Some(preserved) - } else { - Some(old_registers) - } + let has_relevant_after_triggers = get_relevant_triggers_type_and_time( + t_ctx.resolver.schema, + TriggerEvent::Update, + TriggerTime::After, + Some(updated_column_indices), + &btree_table, + ) + .clone() + .count() + > 0; + if has_relevant_after_triggers { + // Preserve pseudo-row 'OLD' for AFTER triggers by copying to new registers + // (since registers might be overwritten during trigger execution) + let preserved: Vec = old_registers + .iter() + .map(|old_reg| { + let preserved_reg = program.alloc_register(); + program.emit_insn(Insn::Copy { + src_reg: *old_reg, + dst_reg: preserved_reg, + extra_amount: 0, + }); + preserved_reg + }) + .collect(); + Some(preserved) + } else { + Some(old_registers) } - } else { - None - }; + } + } else { + None + }; // If BEFORE UPDATE triggers fired, they may have modified the row being updated. // According to the SQLite documentation, the behavior in these cases is undefined: @@ -1920,7 +2019,11 @@ fn emit_update_insns<'a>( // so we can emit Insn::Column instructions and refer to the old values. let where_clause = index .bind_where_expr(Some(table_references), connection) - .expect("where clause to exist"); + .ok_or_else(|| { + crate::LimboError::InternalError( + "where clause not found for partial index".to_string(), + ) + })?; let old_satisfied_reg = program.alloc_register(); translate_expr_no_constant_opt( program, @@ -1935,7 +2038,11 @@ fn emit_update_insns<'a>( let mut new_where = index .where_clause .as_ref() - .expect("checked where clause to exist") + .ok_or_else(|| { + crate::LimboError::InternalError( + "where clause not found for partial index".to_string(), + ) + })? .clone(); // Now we need to rewrite the Expr::Id and Expr::Qualified/Expr::RowID (from a copy of the original, un-bound `where` expr), // to refer to the new values, which are already loaded into registers starting at `start`. 
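The emitter hunks follow the same policy for codegen state: labels, registers, and cursor ids that earlier compilation phases were expected to set are now checked and reported as `InternalError` instead of panicking, so a planner bug surfaces as a failed statement rather than a crashed process. An illustrative reduction, assuming a simplified context struct and error type (the crate's `TranslateCtx` carries far more fields):

```rust
#[derive(Debug)]
enum LimboError {
    InternalError(String),
}

// Hypothetical stand-in for the translation context used above.
struct TranslateCtx {
    label_main_loop_end: Option<u32>,
}

fn main_loop_end(t_ctx: &TranslateCtx) -> Result<u32, LimboError> {
    // Previously `t_ctx.label_main_loop_end.unwrap()`: a missing label would
    // abort the process. Now it becomes an error bubbled up to the caller.
    t_ctx.label_main_loop_end.ok_or_else(|| {
        LimboError::InternalError("label_main_loop_end was not set".to_string())
    })
}

fn main() {
    let ctx = TranslateCtx {
        label_main_loop_end: None,
    };
    assert!(main_loop_end(&ctx).is_err());
}
```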
@@ -1973,7 +2080,9 @@ fn emit_update_insns<'a>(
 // If the old values don't satisfy the WHERE clause, skip the delete
 program.emit_insn(Insn::IfNot {
 reg: old_satisfied,
- target_pc: skip_delete_label.unwrap(),
+ target_pc: skip_delete_label.ok_or_else(|| {
+ crate::LimboError::InternalError("skip_delete_label was not set".to_string())
+ })?,
 jump_if_null: true,
 });
 }
@@ -2014,7 +2123,9 @@ fn emit_update_insns<'a>(
 // If the new values don't satisfy the WHERE clause, skip the idx insert
 program.emit_insn(Insn::IfNot {
 reg: new_satisfied,
- target_pc: skip_insert_label.unwrap(),
+ target_pc: skip_insert_label.ok_or_else(|| {
+ crate::LimboError::InternalError("skip_insert_label was not set".to_string())
+ })?,
 jump_if_null: true,
 });
 }
@@ -2067,7 +2178,9 @@ fn emit_update_insns<'a>(
 .collect::<String>();
 program.emit_insn(Insn::Affinity {
 start_reg: idx_start_reg,
- count: NonZeroUsize::new(num_cols).expect("nonzero col count"),
+ count: NonZeroUsize::new(num_cols).ok_or_else(|| {
+ crate::LimboError::InternalError("index column count is zero".to_string())
+ })?,
 affinities: aff,
 });
 let constraint_check = program.allocate_label();
@@ -2142,7 +2255,9 @@ fn emit_update_insns<'a>(
 if has_user_provided_rowid {
 let record_label = program.allocate_label();
- let target_reg = rowid_set_clause_reg.unwrap();
+ let target_reg = rowid_set_clause_reg.ok_or_else(|| {
+ crate::LimboError::InternalError("rowid_set_clause_reg was not set".to_string())
+ })?;
 program.emit_insn(Insn::Eq {
 lhs: target_reg,
@@ -2165,7 +2280,12 @@ fn emit_update_insns<'a>(
 .table
 .columns()
 .get(idx)
- .unwrap()
+ .ok_or_else(|| {
+ crate::LimboError::InternalError(format!(
+ "column index {} out of bounds",
+ idx
+ ))
+ })?
 .name
 .as_ref()
 .map_or("", |v| v)
@@ -2202,7 +2322,11 @@ fn emit_update_insns<'a>(
 program.emit_insn(Insn::NotExists {
 cursor: target_table_cursor_id,
 rowid_reg: beg,
- target_pc: check_rowid_not_exists_label.unwrap(),
+ target_pc: check_rowid_not_exists_label.ok_or_else(|| {
+ crate::LimboError::InternalError(
+ "check_rowid_not_exists_label was not set".to_string(),
+ )
+ })?,
 });
 }
@@ -2231,7 +2355,9 @@ fn emit_update_insns<'a>(
 program,
 target_table.table.columns(),
 target_table_cursor_id,
- cdc_rowid_before_reg.expect("cdc_rowid_before_reg must be set"),
+ cdc_rowid_before_reg.ok_or_else(|| {
+ crate::LimboError::InternalError("cdc_rowid_before_reg must be set".to_string())
+ })?,
 ))
 } else {
 None
@@ -2345,8 +2471,9 @@ fn emit_update_insns<'a>(
 // emit actual CDC instructions for write to the CDC table
 if let Some(cdc_cursor_id) = t_ctx.cdc_cursor_id {
- let cdc_rowid_before_reg =
- cdc_rowid_before_reg.expect("cdc_rowid_before_reg must be set");
+ let cdc_rowid_before_reg = cdc_rowid_before_reg.ok_or_else(|| {
+ crate::LimboError::InternalError("cdc_rowid_before_reg must be set".to_string())
+ })?;
 if has_user_provided_rowid {
 emit_cdc_insns(
 program,
@@ -2404,7 +2531,9 @@ fn emit_update_insns<'a>(
 if let Some(limit_ctx) = t_ctx.limit_ctx {
 program.emit_insn(Insn::DecrJumpZero {
 reg: limit_ctx.reg_limit,
- target_pc: t_ctx.label_main_loop_end.unwrap(),
+ target_pc: t_ctx.label_main_loop_end.ok_or_else(|| {
+ crate::LimboError::InternalError("label_main_loop_end was not set".to_string())
+ })?,
 })
 }
 // TODO(pthorpe): handle RETURNING clause
@@ -2650,7 +2779,12 @@ fn init_limit(
 });
 } else {
 program.emit_insn(Insn::Real {
- value: n.parse::<f64>().unwrap(),
+ value: n.parse::<f64>().map_err(|_| {
+ crate::LimboError::InternalError(format!(
+ "invalid float literal: {}",
+ n
+ ))
+ })?,
 dest: limit_ctx.reg_limit,
 });
program.add_comment(program.offset(), "LIMIT counter"); @@ -2710,9 +2844,11 @@ fn init_limit( } // exit early if LIMIT 0 - let main_loop_end = t_ctx - .label_main_loop_end - .expect("label_main_loop_end must be set before init_limit"); + let main_loop_end = t_ctx.label_main_loop_end.ok_or_else(|| { + crate::LimboError::InternalError( + "label_main_loop_end must be set before init_limit".to_string(), + ) + })?; program.emit_insn(Insn::IfNot { reg: limit_ctx.reg_limit, target_pc: main_loop_end, @@ -2833,9 +2969,12 @@ fn emit_index_column_value_new_image( NoConstantOptReason::RegisterReuse, )?; } else { - let col_in_table = columns - .get(idx_col.pos_in_table) - .expect("column index out of bounds"); + let col_in_table = columns.get(idx_col.pos_in_table).ok_or_else(|| { + crate::LimboError::InternalError(format!( + "column index {} out of bounds", + idx_col.pos_in_table + )) + })?; let src_reg = if col_in_table.is_rowid_alias() { rowid_reg } else { diff --git a/core/translate/expr.rs b/core/translate/expr.rs index 68429ccffc..9e788b8e20 100644 --- a/core/translate/expr.rs +++ b/core/translate/expr.rs @@ -498,13 +498,19 @@ pub fn translate_condition_expr( if *not { // When IN is TRUE (match found), NOT IN should be FALSE - program.resolve_label(not_true_label.unwrap(), program.offset()); + let label = not_true_label.ok_or_else(|| { + crate::LimboError::InternalError("not_true_label not set".to_string()) + })?; + program.resolve_label(label, program.offset()); program.emit_insn(Insn::Goto { target_pc: jump_target_when_false, }); // When IN is FALSE (no match), NOT IN should be TRUE - program.resolve_label(not_false_label.unwrap(), program.offset()); + let label = not_false_label.ok_or_else(|| { + crate::LimboError::InternalError("not_false_label not set".to_string()) + })?; + program.resolve_label(label, program.offset()); program.emit_insn(Insn::Goto { target_pc: jump_target_when_true, }); @@ -687,7 +693,10 @@ pub fn translate_expr( value: 0, dest: target_register, }); - let lhs_columns = match unwrap_parens(lhs.as_ref().unwrap())? { + let lhs = lhs.as_ref().ok_or_else(|| { + crate::LimboError::InternalError("lhs is None".to_string()) + })?; + let lhs_columns = match unwrap_parens(lhs)? { ast::Expr::Parenthesized(exprs) => { exprs.iter().map(|e| e.as_ref()).collect() } @@ -703,7 +712,12 @@ pub fn translate_expr( lhs_column_regs_start + i, resolver, )?; - if !lhs_column.is_nonnull(referenced_tables.as_ref().unwrap()) { + let ref_tables = referenced_tables.as_ref().ok_or_else(|| { + crate::LimboError::InternalError( + "referenced_tables is None".to_string(), + ) + })?; + if !lhs_column.is_nonnull(ref_tables) { program.emit_insn(Insn::IsNull { reg: lhs_column_regs_start + i, target_pc: if *not_in { @@ -839,13 +853,10 @@ pub fn translate_expr( let base_reg = base.as_ref().map(|_| program.alloc_register()); let expr_reg = program.alloc_register(); if let Some(base_expr) = base { - translate_expr( - program, - referenced_tables, - base_expr, - base_reg.unwrap(), - resolver, - )?; + let reg = base_reg.ok_or_else(|| { + crate::LimboError::InternalError("base_reg is None".to_string()) + })?; + translate_expr(program, referenced_tables, base_expr, reg, resolver)?; }; for (when_expr, then_expr) in when_then_pairs { translate_expr_no_constant_opt( @@ -913,7 +924,9 @@ pub fn translate_expr( Ok(target_register) } ast::Expr::Cast { expr, type_name } => { - let type_name = type_name.as_ref().unwrap(); // TODO: why is this optional? 
+ let type_name = type_name.as_ref().ok_or_else(|| { + crate::LimboError::ParseError("CAST requires a type name".to_string()) + })?; translate_expr(program, referenced_tables, expr, target_register, resolver)?; let type_affinity = Affinity::affinity(&type_name.name); program.emit_insn(Insn::Cast { @@ -950,8 +963,11 @@ pub fn translate_expr( crate::bail_parse_error!("unknown function {}", name.as_str()); } + let func = func_type.ok_or_else(|| { + crate::LimboError::ParseError(format!("unknown function {}", name.as_str())) + })?; let func_ctx = FuncCtx { - func: func_type.unwrap(), + func, arg_count: args_count, }; @@ -1278,9 +1294,12 @@ pub fn translate_expr( start_reg = Some(start_reg.unwrap_or(reg)); translate_expr(program, referenced_tables, arg, reg, resolver)?; } + let reg = start_reg.ok_or_else(|| { + crate::LimboError::InternalError("start_reg not set".to_string()) + })?; program.emit_insn(Insn::Function { constant_mask: 0, - start_reg: start_reg.unwrap(), + start_reg: reg, dest: target_register, func: func_ctx, }); @@ -1395,10 +1414,12 @@ pub fn translate_expr( } if args.len() % 2 != 0 { + let last_arg = + args.last().expect("args.len() % 2 != 0 implies non-empty"); translate_expr_no_constant_opt( program, referenced_tables, - args.last().unwrap(), + last_arg, target_register, resolver, NoConstantOptReason::RegisterReuse, @@ -2091,8 +2112,11 @@ pub fn translate_expr( crate::bail_parse_error!("unknown function {}", name.as_str()); } + let func = func_type.ok_or_else(|| { + crate::LimboError::ParseError(format!("unknown function {}", name.as_str())) + })?; let func_ctx = FuncCtx { - func: func_type.unwrap(), + func, arg_count: args_count, }; @@ -2146,9 +2170,11 @@ pub fn translate_expr( is_rowid_alias, } => { let (index, index_method, use_covering_index) = { - if let Some(table_reference) = referenced_tables - .unwrap() - .find_joined_table_by_internal_id(*table_ref_id) + let ref_tables = referenced_tables.ok_or_else(|| { + crate::LimboError::InternalError("referenced_tables is None".to_string()) + })?; + if let Some(table_reference) = + ref_tables.find_joined_table_by_internal_id(*table_ref_id) { ( table_reference.op.index(), @@ -2165,8 +2191,10 @@ pub fn translate_expr( }; let use_index_method = index_method.and_then(|m| m.covered_columns.get(column)); - let (is_from_outer_query_scope, table) = referenced_tables - .unwrap() + let ref_tables = referenced_tables.ok_or_else(|| { + crate::LimboError::InternalError("referenced_tables is None".to_string()) + })?; + let (is_from_outer_query_scope, table) = ref_tables .find_table_by_internal_id(*table_ref_id) .unwrap_or_else(|| { unreachable!( @@ -2217,8 +2245,13 @@ pub fn translate_expr( }; if let Some(custom_module_column) = use_index_method { + let cursor_id = index_cursor_id.ok_or_else(|| { + crate::LimboError::InternalError( + "index_cursor_id not set for index method".to_string(), + ) + })?; program.emit_column_or_rowid( - index_cursor_id.unwrap(), + cursor_id, *custom_module_column, target_register, ); @@ -2302,9 +2335,11 @@ pub fn translate_expr( table: table_ref_id, } => { let (index, use_covering_index) = { - if let Some(table_reference) = referenced_tables - .unwrap() - .find_joined_table_by_internal_id(*table_ref_id) + let ref_tables = referenced_tables.ok_or_else(|| { + crate::LimboError::InternalError("referenced_tables is None".to_string()) + })?; + if let Some(table_reference) = + ref_tables.find_joined_table_by_internal_id(*table_ref_id) { ( table_reference.op.index(), @@ -3397,7 +3432,10 @@ pub fn unwrap_parens(expr: 
&ast::Expr) -> Result<&ast::Expr> { match expr { ast::Expr::Column { .. } => Ok(expr), ast::Expr::Parenthesized(exprs) => match exprs.len() { - 1 => unwrap_parens(exprs.first().unwrap()), + 1 => { + let first = exprs.first().expect("checked length == 1"); + unwrap_parens(first) + } _ => Ok(expr), // If the expression is e.g. (x, y), as used in e.g. (x, y) IN (SELECT ...), return as is. }, _ => Ok(expr), @@ -3412,7 +3450,8 @@ pub fn unwrap_parens_owned(expr: ast::Expr) -> Result<(ast::Expr, usize)> { ast::Expr::Parenthesized(mut exprs) => match exprs.len() { 1 => { paren_count += 1; - let (expr, count) = unwrap_parens_owned(*exprs.pop().unwrap().clone())?; + let last = exprs.pop().expect("checked length == 1"); + let (expr, count) = unwrap_parens_owned(*last.clone())?; paren_count += count; Ok((expr, paren_count)) } @@ -3712,11 +3751,20 @@ pub fn bind_and_rewrite_expr<'a>( crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); } } else { + let col_idx_val = col_idx + .expect("col_idx must be Some when match_result is None"); let col = - joined_table.table.columns().get(col_idx.unwrap()).unwrap(); + joined_table.table.columns().get(col_idx_val).ok_or_else( + || { + crate::LimboError::InternalError(format!( + "column index {} out of bounds", + col_idx_val + )) + }, + )?; match_result = Some(( joined_table.internal_id, - col_idx.unwrap(), + col_idx_val, col.is_rowid_alias(), )); } @@ -3755,10 +3803,19 @@ pub fn bind_and_rewrite_expr<'a>( if match_result.is_some() { crate::bail_parse_error!("Column {} is ambiguous", id.as_str()); } - let col = outer_ref.table.columns().get(col_idx.unwrap()).unwrap(); + let col_idx_val = + col_idx.expect("col_idx must be Some when is_some() is true"); + let col = outer_ref.table.columns().get(col_idx_val).ok_or_else( + || { + crate::LimboError::InternalError(format!( + "column index {} out of bounds", + col_idx_val + )) + }, + )?; match_result = Some(( outer_ref.internal_id, - col_idx.unwrap(), + col_idx_val, col.is_rowid_alias(), )); } @@ -3818,7 +3875,8 @@ pub fn bind_and_rewrite_expr<'a>( if matching_tbl.is_none() { crate::bail_parse_error!("no such table: {}", normalized_table_name); } - let (tbl_id, tbl) = matching_tbl.unwrap(); + let (tbl_id, tbl) = + matching_tbl.expect("matching_tbl must be Some when is_none() is false"); let normalized_id = normalize_ident(id.as_str()); let col_idx = tbl.columns().iter().position(|c| { c.name @@ -3833,7 +3891,12 @@ pub fn bind_and_rewrite_expr<'a>( let Some(col_idx) = col_idx else { crate::bail_parse_error!("no such column: {}", normalized_id); }; - let col = tbl.columns().get(col_idx).unwrap(); + let col = tbl.columns().get(col_idx).ok_or_else(|| { + crate::LimboError::InternalError(format!( + "column index {} out of bounds", + col_idx + )) + })?; *expr = Expr::Column { database: None, // TODO: support different databases table: tbl_id, @@ -3895,7 +3958,12 @@ pub fn bind_and_rewrite_expr<'a>( )) })?; - let col = table.columns().get(col_idx).unwrap(); + let col = table.columns().get(col_idx).ok_or_else(|| { + crate::LimboError::InternalError(format!( + "column index {} out of bounds", + col_idx + )) + })?; // Check if this is a rowid alias let is_rowid_alias = col.is_rowid_alias(); @@ -4136,9 +4204,12 @@ pub fn get_expr_affinity( Affinity::Blob } } - ast::Expr::Parenthesized(exprs) if exprs.len() == 1 => { - get_expr_affinity(exprs.first().unwrap(), referenced_tables) - } + ast::Expr::Parenthesized(exprs) if exprs.len() == 1 => get_expr_affinity( + exprs + .first() + .expect("exprs.len() == 1 guarantees 
first() returns Some"), + referenced_tables, + ), ast::Expr::Collate(expr, _) => get_expr_affinity(expr, referenced_tables), // Literals have NO affinity in SQLite! ast::Expr::Literal(_) => Affinity::Blob, // No affinity! @@ -4227,9 +4298,9 @@ pub fn emit_literal( .chunks_exact(2) .map(|pair| { // We assume that sqlite3-parser has already validated that - // the input is valid hex string, thus unwrap is safe. - let hex_byte = std::str::from_utf8(pair).unwrap(); - u8::from_str_radix(hex_byte, 16).unwrap() + // the input is valid hex string, thus expect is safe. + let hex_byte = std::str::from_utf8(pair).expect("parser validated hex string"); + u8::from_str_radix(hex_byte, 16).expect("parser validated hex digit") }) .collect(); program.emit_insn(Insn::Blob { @@ -4386,7 +4457,10 @@ pub(crate) fn emit_returning_results<'a>( } turso_assert!(table_references.joined_tables().len() == 1, "RETURNING is only used with INSERT, UPDATE, or DELETE statements, which target a single table"); - let table = table_references.joined_tables().first().unwrap(); + let table = table_references + .joined_tables() + .first() + .ok_or_else(|| crate::LimboError::InternalError("no joined tables".to_string()))?; resolver.enable_expr_to_reg_cache(); let expr = Expr::RowId { diff --git a/core/translate/index.rs b/core/translate/index.rs index 95bc4a4f31..8cd5f7849c 100644 --- a/core/translate/index.rs +++ b/core/translate/index.rs @@ -167,7 +167,9 @@ pub fn translate_create_index( // 3. table_cursor_id - table we are creating the index on // 4. sorter_cursor_id - sorter // 5. pseudo_cursor_id - pseudo table to store the sorted index values - let sqlite_table = resolver.schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let Some(sqlite_table) = resolver.schema.get_btree_table(SQLITE_TABLEID) else { + crate::bail_parse_error!("sqlite_schema table not found"); + }; let sqlite_schema_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_table.clone())); let table_ref = program.table_reference_counter.next(); @@ -715,12 +717,16 @@ pub fn resolve_index_method_parameters( b.as_bytes() .chunks_exact(2) .map(|pair| { - // We assume that sqlite3-parser has already validated that - // the input is valid hex string, thus unwrap is safe. 
- let hex_byte = std::str::from_utf8(pair).unwrap();
- u8::from_str_radix(hex_byte, 16).unwrap()
+ let hex_byte = std::str::from_utf8(pair).map_err(|_| {
+ crate::LimboError::ParseError(
+ "invalid UTF-8 in hex string".to_string(),
+ )
+ })?;
+ u8::from_str_radix(hex_byte, 16).map_err(|_| {
+ crate::LimboError::ParseError("invalid hex digit".to_string())
+ })
 })
- .collect(),
+ .collect::<Result<Vec<u8>, _>>()?,
 ),
 _ => bail_parse_error!("parameters must be constant literals"),
 },
@@ -802,7 +808,9 @@ pub fn translate_drop_index(
 let row_id_reg = program.alloc_register();
 // We're going to use this cursor to search through sqlite_schema
- let sqlite_table = resolver.schema.get_btree_table(SQLITE_TABLEID).unwrap();
+ let Some(sqlite_table) = resolver.schema.get_btree_table(SQLITE_TABLEID) else {
+ crate::bail_parse_error!("sqlite_schema table not found");
+ };
 let sqlite_schema_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(sqlite_table.clone()));
@@ -904,7 +912,9 @@ pub fn translate_drop_index(
 p5: 0,
 });
- let index = maybe_index.unwrap();
+ let Some(index) = maybe_index else {
+ crate::bail_parse_error!("index not found");
+ };
 if index.index_method.is_some() && !index.is_backing_btree_index() {
 let cursor_id = program.alloc_cursor_index(None, index)?;
 program.emit_insn(Insn::IndexMethodDestroy { db: 0, cursor_id });
diff --git a/core/translate/insert.rs b/core/translate/insert.rs
index 78867a88d4..7b48c55c80 100644
--- a/core/translate/insert.rs
+++ b/core/translate/insert.rs
@@ -741,9 +741,11 @@ fn translate_rows_and_open_tables(
 inserting_multiple_rows: bool,
 ) -> Result<()> {
 if inserting_multiple_rows {
- let select_result_start_reg = program
- .reg_result_cols_start
- .unwrap_or(ctx.yield_reg_opt.unwrap() + 1);
+ let select_result_start_reg = program.reg_result_cols_start.unwrap_or(
+ ctx.yield_reg_opt
+ .expect("yield_reg_opt must be set when inserting multiple rows")
+ + 1,
+ );
 translate_rows_multiple(
 program,
 insertion,
@@ -1347,7 +1349,7 @@ fn init_source_emission<'a>(
 }
 table
 .get_column_by_name(&column_name)
- .unwrap()
+ .expect("column must exist in table")
 .1
 .affinity()
 .aff_mask()
@@ -2117,7 +2119,10 @@ fn translate_virtual_table_insert(
 }
 let (num_values, value) = match &mut body {
 InsertBody::Select(select, None) => match &mut select.body.select {
- OneSelect::Values(values) => (values[0].len(), values.pop().unwrap()),
+ OneSelect::Values(values) => (
+ values[0].len(),
+ values.pop().expect("values must not be empty for INSERT"),
+ ),
 _ => crate::bail_parse_error!("Virtual tables only support VALUES clause in INSERT"),
 },
 InsertBody::DefaultValues => (0, vec![]),
@@ -2490,7 +2495,9 @@ fn emit_update_sqlite_sequence(
 extra_amount: 0,
 });
- let seq_table = schema.get_btree_table("sqlite_sequence").unwrap();
+ let seq_table = schema
+ .get_btree_table("sqlite_sequence")
+ .expect("sqlite_sequence table must exist");
 let affinity_str = seq_table
 .columns
 .iter()
@@ -2695,7 +2702,9 @@ pub fn emit_fk_child_insert_checks(
 // Short-circuit if any NEW component is NULL
 let fk_ok = program.allocate_label();
 for cname in &fk_ref.child_cols {
- let (i, col) = child_tbl.get_column(cname).unwrap();
+ let (i, col) = child_tbl
+ .get_column(cname)
+ .expect("child column must exist");
 let src = if col.is_rowid_alias() {
 new_rowid_reg
 } else {
@@ -2714,7 +2723,9 @@
 let pcur = open_read_table(program, &parent_tbl);
 // first child col carries rowid
- let (i_child, col_child) = child_tbl.get_column(&fk_ref.child_cols[0]).unwrap();
+ let (i_child,
col_child) = child_tbl + .get_column(&fk_ref.child_cols[0]) + .expect("first child column must exist"); let val_reg = if col_child.is_rowid_alias() { new_rowid_reg } else { @@ -2768,7 +2779,9 @@ pub fn emit_fk_child_insert_checks( let probe = { let start = program.alloc_registers(ncols); for (k, cname) in fk_ref.child_cols.iter().enumerate() { - let (i, col) = child_tbl.get_column(cname).unwrap(); + let (i, col) = child_tbl + .get_column(cname) + .expect("child column must exist"); program.emit_insn(Insn::Copy { src_reg: if col.is_rowid_alias() { new_rowid_reg @@ -2793,10 +2806,12 @@ pub fn emit_fk_child_insert_checks( let parent_cols: Vec<&str> = idx.columns.iter().map(|ic| ic.name.as_str()).collect(); - // Build new parent-key image from this same row’s new values, in the index order. + // Build new parent-key image from this same row's new values, in the index order. let parent_new = program.alloc_registers(ncols); for (i, pname) in parent_cols.iter().enumerate() { - let (pos, col) = child_tbl.get_column(pname).unwrap(); + let (pos, col) = child_tbl + .get_column(pname) + .expect("parent column must exist in child table"); program.emit_insn(Insn::Copy { src_reg: if col.is_rowid_alias() { new_rowid_reg diff --git a/core/translate/main_loop.rs b/core/translate/main_loop.rs index f659e1619e..3b7ee2212c 100644 --- a/core/translate/main_loop.rs +++ b/core/translate/main_loop.rs @@ -234,7 +234,10 @@ pub fn init_loop( if let Some(index_cursor_id) = index_cursor_id { program.emit_insn(Insn::OpenRead { cursor_id: index_cursor_id, - root_page: index.as_ref().unwrap().root_page, + root_page: index + .as_ref() + .expect("index must exist when index_cursor_id is Some") + .root_page, db: table.database_id, }); } @@ -250,7 +253,11 @@ pub fn init_loop( if let Some(index_cursor_id) = index_cursor_id { program.emit_insn(Insn::OpenWrite { cursor_id: index_cursor_id, - root_page: index.as_ref().unwrap().root_page.into(), + root_page: index + .as_ref() + .expect("index must exist when index_cursor_id is Some") + .root_page + .into(), db: table.database_id, }); } @@ -291,7 +298,11 @@ pub fn init_loop( .resolve_cursor_id(&CursorKey::table(target_table.internal_id)); program.emit_insn(Insn::OpenWrite { cursor_id: target_table_cursor_id, - root_page: target_table.btree().unwrap().root_page.into(), + root_page: target_table + .btree() + .expect("target table must be a BTree table") + .root_page + .into(), db: table.database_id, }); } @@ -299,7 +310,11 @@ pub fn init_loop( if let Some(index_cursor_id) = index_cursor_id { program.emit_insn(Insn::OpenWrite { cursor_id: index_cursor_id, - root_page: index.as_ref().unwrap().root_page.into(), + root_page: index + .as_ref() + .expect("index must exist when index_cursor_id is Some") + .root_page + .into(), db: table.database_id, }); } @@ -415,10 +430,15 @@ pub fn init_loop( db: table.database_id, }); } - let index_cursor_id = index_cursor_id.unwrap(); + let index_cursor_id = index_cursor_id + .expect("index_cursor_id must be Some for index method query"); program.emit_insn(Insn::OpenRead { cursor_id: index_cursor_id, - root_page: table.op.index().unwrap().root_page, + root_page: table + .op + .index() + .expect("operation must be index method query") + .root_page, db: table.database_id, }); } @@ -435,8 +455,12 @@ pub fn init_loop( let meta = ConditionMetadata { jump_if_condition_is_true: false, jump_target_when_true: jump_target, - jump_target_when_false: t_ctx.label_main_loop_end.unwrap(), - jump_target_when_null: t_ctx.label_main_loop_end.unwrap(), + 
jump_target_when_false: t_ctx + .label_main_loop_end + .expect("label_main_loop_end must be set"), + jump_target_when_null: t_ctx + .label_main_loop_end + .expect("label_main_loop_end must be set"), }; translate_condition_expr(program, tables, &cond.expr, meta, &t_ctx.resolver)?; program.preassign_label_to_next_insn(jump_target); @@ -476,7 +500,9 @@ pub fn open_loop( // This is used to determine whether to emit actual columns or NULLs for the columns of the right table. if let Some(join_info) = table.join_info.as_ref() { if join_info.outer { - let lj_meta = t_ctx.meta_left_joins[joined_table_index].as_ref().unwrap(); + let lj_meta = t_ctx.meta_left_joins[joined_table_index] + .as_ref() + .expect("left join metadata must exist for outer join"); program.emit_insn(Insn::Integer { value: 0, dest: lj_meta.reg_match_flag, @@ -746,7 +772,9 @@ pub fn open_loop( // for the right table's cursor. if let Some(join_info) = table.join_info.as_ref() { if join_info.outer { - let lj_meta = t_ctx.meta_left_joins[joined_table_index].as_ref().unwrap(); + let lj_meta = t_ctx.meta_left_joins[joined_table_index] + .as_ref() + .expect("left join metadata must exist for outer join"); program.resolve_label(lj_meta.label_match_flag_set_true, program.offset()); program.emit_insn(Insn::Integer { value: 1, @@ -914,7 +942,10 @@ fn emit_loop_source( row_source, registers, .. - } = t_ctx.meta_group_by.as_ref().unwrap(); + } = t_ctx + .meta_group_by + .as_ref() + .expect("meta_group_by must be set when emitting to group by"); let start_reg = registers.reg_group_by_source_cols_start; let mut cur_reg = start_reg; @@ -1026,7 +1057,9 @@ fn emit_loop_source( None }; - let col_start = t_ctx.reg_result_cols_start.unwrap(); + let col_start = t_ctx + .reg_result_cols_start + .expect("reg_result_cols_start must be set"); // Process only non-aggregate columns let non_agg_columns = plan @@ -1048,7 +1081,9 @@ fn emit_loop_source( } if let Some(label) = label_emit_nonagg_only_once { program.resolve_label(label, program.offset()); - let flag = t_ctx.reg_nonagg_emit_once_flag.unwrap(); + let flag = t_ctx + .reg_nonagg_emit_once_flag + .expect("reg_nonagg_emit_once_flag must be set"); program.emit_int(1, flag); } @@ -1072,7 +1107,9 @@ fn emit_loop_source( offset_jump_to, t_ctx.reg_nonagg_emit_once_flag, t_ctx.reg_offset, - t_ctx.reg_result_cols_start.unwrap(), + t_ctx + .reg_result_cols_start + .expect("reg_result_cols_start must be set"), t_ctx.limit_ctx, )?; @@ -1212,7 +1249,8 @@ pub fn close_loop( Operation::IndexMethodQuery(_) => { program.resolve_label(loop_labels.next, program.offset()); program.emit_insn(Insn::Next { - cursor_id: index_cursor_id.unwrap(), + cursor_id: index_cursor_id + .expect("index_cursor_id must be Some for index method query"), pc_if_next: loop_labels.loop_start, }); program.preassign_label_to_next_insn(loop_labels.loop_end); @@ -1223,7 +1261,9 @@ pub fn close_loop( // and emit a row with NULLs for the right table, and then jump back to the next row of the left table. if let Some(join_info) = table.join_info.as_ref() { if join_info.outer { - let lj_meta = t_ctx.meta_left_joins[table_index].as_ref().unwrap(); + let lj_meta = t_ctx.meta_left_joins[table_index] + .as_ref() + .expect("left join metadata must exist for outer join"); // The left join match flag is set to 1 when there is any match on the right table // (e.g. SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a). 
// If the left join match flag has been set to 1, we jump to the next row on the outer table, @@ -1387,7 +1427,7 @@ fn emit_seek( if affinities.chars().any(|c| c != affinity::SQLITE_AFF_NONE) { program.emit_insn(Insn::Affinity { start_reg, - count: std::num::NonZeroUsize::new(num_regs).unwrap(), + count: std::num::NonZeroUsize::new(num_regs).expect("num_regs must be non-zero"), affinities: seek_def .iter_affinity(&seek_def.start) .map(|affinity| affinity.aff_mask()) @@ -1509,7 +1549,7 @@ fn emit_seek_termination( rowid_reg = Some(program.alloc_register()); program.emit_insn(Insn::RowId { cursor_id: seek_cursor_id, - dest: rowid_reg.unwrap(), + dest: rowid_reg.expect("rowid_reg was just set to Some"), }); affinity = if let Some(table_ref) = tables @@ -1553,39 +1593,39 @@ fn emit_seek_termination( target_pc: loop_end, }), (false, SeekOp::GE { .. }) => program.emit_insn(Insn::Ge { - lhs: rowid_reg.unwrap(), + lhs: rowid_reg.expect("rowid_reg must be Some when is_index is false"), rhs: start_reg, target_pc: loop_end, flags: CmpInsFlags::default() .jump_if_null() - .with_affinity(affinity.unwrap()), + .with_affinity(affinity.expect("affinity must be Some when is_index is false")), collation: program.curr_collation(), }), (false, SeekOp::GT) => program.emit_insn(Insn::Gt { - lhs: rowid_reg.unwrap(), + lhs: rowid_reg.expect("rowid_reg must be Some when is_index is false"), rhs: start_reg, target_pc: loop_end, flags: CmpInsFlags::default() .jump_if_null() - .with_affinity(affinity.unwrap()), + .with_affinity(affinity.expect("affinity must be Some when is_index is false")), collation: program.curr_collation(), }), (false, SeekOp::LE { .. }) => program.emit_insn(Insn::Le { - lhs: rowid_reg.unwrap(), + lhs: rowid_reg.expect("rowid_reg must be Some when is_index is false"), rhs: start_reg, target_pc: loop_end, flags: CmpInsFlags::default() .jump_if_null() - .with_affinity(affinity.unwrap()), + .with_affinity(affinity.expect("affinity must be Some when is_index is false")), collation: program.curr_collation(), }), (false, SeekOp::LT) => program.emit_insn(Insn::Lt { - lhs: rowid_reg.unwrap(), + lhs: rowid_reg.expect("rowid_reg must be Some when is_index is false"), rhs: start_reg, target_pc: loop_end, flags: CmpInsFlags::default() .jump_if_null() - .with_affinity(affinity.unwrap()), + .with_affinity(affinity.expect("affinity must be Some when is_index is false")), collation: program.curr_collation(), }), }; diff --git a/core/translate/optimizer/constraints.rs b/core/translate/optimizer/constraints.rs index 0949e104e1..b2c6c810e8 100644 --- a/core/translate/optimizer/constraints.rs +++ b/core/translate/optimizer/constraints.rs @@ -288,18 +288,19 @@ pub fn constraints_from_where_clause( // A rowid alias column must exist for the 'rowid' keyword to be considered a valid reference. // This should be a parse error at an earlier stage of the query compilation, but nevertheless, // we check it here. 
- if *table == table_reference.internal_id && rowid_alias_column.is_some() {
- let table_column =
- &table_reference.table.columns()[rowid_alias_column.unwrap()];
- cs.constraints.push(Constraint {
- where_clause_pos: (i, BinaryExprSide::Rhs),
- operator,
- table_col_pos: rowid_alias_column,
- expr: None,
- lhs_mask: table_mask_from_expr(rhs, table_references, subqueries)?,
- selectivity: estimate_selectivity(Some(table_column), operator),
- usable: true,
- });
+ if *table == table_reference.internal_id {
+ if let Some(rowid_col_pos) = rowid_alias_column {
+ let table_column = &table_reference.table.columns()[rowid_col_pos];
+ cs.constraints.push(Constraint {
+ where_clause_pos: (i, BinaryExprSide::Rhs),
+ operator,
+ table_col_pos: Some(rowid_col_pos),
+ expr: None,
+ lhs_mask: table_mask_from_expr(rhs, table_references, subqueries)?,
+ selectivity: estimate_selectivity(Some(table_column), operator),
+ usable: true,
+ });
+ }
 }
 }
 _ if expression_matches_table(
@@ -337,18 +338,19 @@ pub fn constraints_from_where_clause(
 }
 }
 ast::Expr::RowId { table, .. } => {
- if *table == table_reference.internal_id && rowid_alias_column.is_some() {
- let table_column =
- &table_reference.table.columns()[rowid_alias_column.unwrap()];
- cs.constraints.push(Constraint {
- where_clause_pos: (i, BinaryExprSide::Lhs),
- operator: opposite_cmp_op(operator),
- table_col_pos: rowid_alias_column,
- expr: None,
- lhs_mask: table_mask_from_expr(lhs, table_references, subqueries)?,
- selectivity: estimate_selectivity(Some(table_column), operator),
- usable: true,
- });
+ if *table == table_reference.internal_id {
+ if let Some(rowid_col_pos) = rowid_alias_column {
+ let table_column = &table_reference.table.columns()[rowid_col_pos];
+ cs.constraints.push(Constraint {
+ where_clause_pos: (i, BinaryExprSide::Lhs),
+ operator: opposite_cmp_op(operator),
+ table_col_pos: Some(rowid_col_pos),
+ expr: None,
+ lhs_mask: table_mask_from_expr(lhs, table_references, subqueries)?,
+ selectivity: estimate_selectivity(Some(table_column), operator),
+ usable: true,
+ });
+ }
 }
 }
 _ if expression_matches_table(
@@ -413,7 +415,7 @@ pub fn constraints_from_where_clause(
 None
 }
 })
- .unwrap();
+ .expect("rowid candidate must exist in candidates list");
 rowid_candidate.refs.push(ConstraintRef {
 constraint_vec_pos: i,
 index_col_pos: 0,
@@ -563,7 +565,10 @@ pub fn usable_constraints_for_join_order<'a>(
 ) -> Vec<RangeConstraintRef> {
 debug_assert!(refs.is_sorted_by_key(|x| x.index_col_pos));
- let table_idx = join_order.last().unwrap().original_idx;
+ let table_idx = join_order
+ .last()
+ .expect("join_order must not be empty")
+ .original_idx;
 let lhs_mask = TableMask::from_table_number_iter(
 join_order
 .iter()
@@ -584,23 +589,32 @@ pub fn usable_constraints_for_join_order<'a>(
 }
 if Some(cref.index_col_pos) == usable.last().map(|x| x.index_col_pos) {
 // Two constraints on the same index column can be combined into a single range constraint.
- assert_eq!(cref.sort_order, usable.last().unwrap().sort_order);
- assert_eq!(cref.index_col_pos, usable.last().unwrap().index_col_pos);
+ let last_usable = usable
+ .last()
+ .expect("usable must not be empty when matching last index_col_pos");
+ assert_eq!(cref.sort_order, last_usable.sort_order);
+ assert_eq!(cref.index_col_pos, last_usable.index_col_pos);
 assert_eq!(
 constraints[cref.constraint_vec_pos].table_col_pos,
- usable.last().unwrap().table_col_pos
+ last_usable.table_col_pos
 );
 // if we already have eq constraint - we must not add anything to it
 // otherwise, we can incorrectly consume filters which will not be used in the access path
- if usable.last().unwrap().eq.is_some() {
+ if last_usable.eq.is_some() {
 continue;
 }
 match constraints[cref.constraint_vec_pos].operator {
 ast::Operator::Greater | ast::Operator::GreaterEquals => {
- usable.last_mut().unwrap().lower_bound = Some(cref.constraint_vec_pos);
+ usable
+ .last_mut()
+ .expect("usable must not be empty")
+ .lower_bound = Some(cref.constraint_vec_pos);
 }
 ast::Operator::Less | ast::Operator::LessEquals => {
- usable.last_mut().unwrap().upper_bound = Some(cref.constraint_vec_pos);
+ usable
+ .last_mut()
+ .expect("usable must not be empty")
+ .upper_bound = Some(cref.constraint_vec_pos);
 }
 _ => {}
 }
@@ -677,7 +691,10 @@ pub fn convert_to_vtab_constraint(
 constraints: &[Constraint],
 join_order: &[JoinOrderMember],
 ) -> Vec<ConstraintInfo> {
- let table_idx = join_order.last().unwrap().original_idx;
+ let table_idx = join_order
+ .last()
+ .expect("join_order must not be empty")
+ .original_idx;
 let lhs_mask = TableMask::from_table_number_iter(
 join_order
 .iter()
diff --git a/core/translate/optimizer/mod.rs b/core/translate/optimizer/mod.rs
index d99ac603e3..525f4c68e0 100644
--- a/core/translate/optimizer/mod.rs
+++ b/core/translate/optimizer/mod.rs
@@ -303,11 +303,10 @@ fn add_ephemeral_table_to_update_plan(
 .is_some_and(|join_info| join_info.outer),
 })
 .collect();
- let rowid_internal_id = table_references_ephemeral_select
- .joined_tables()
- .first()
- .unwrap()
- .internal_id;
+ let Some(first_table) = table_references_ephemeral_select.joined_tables().first() else {
+ crate::bail_parse_error!("ephemeral select must have at least one joined table");
+ };
+ let rowid_internal_id = first_table.internal_id;
 let ephemeral_plan = SelectPlan {
 table_references: table_references_ephemeral_select,
@@ -652,7 +651,10 @@ fn optimize_table_access(
 .iter()
 .any(|c| where_clause[c.where_clause_pos.0].from_outer_join.is_none())
 {
- t.join_info.as_mut().unwrap().outer = false;
+ let Some(join_info) = t.join_info.as_mut() else {
+ crate::bail_parse_error!("join_info must exist for table with outer join");
+ };
+ join_info.outer = false;
 for term in where_clause.iter_mut() {
 if let Some(from_outer_join) = term.from_outer_join {
 if from_outer_join == t.internal_id {
@@ -811,7 +813,7 @@ fn optimize_table_access(
 let ephemeral_index = ephemeral_index_build(
 &table_references.joined_tables_mut()[table_idx],
 &usable_constraint_refs,
- );
+ )?;
 let ephemeral_index = Arc::new(ephemeral_index);
 table_references.joined_tables_mut()[table_idx].op = Operation::Search(Search::Seek {
@@ -1328,19 +1330,28 @@ impl Optimizable for ast::Expr {
 fn ephemeral_index_build(
 table_reference: &JoinedTable,
 constraint_refs: &[RangeConstraintRef],
-) -> Index {
+) -> Result<Index> {
 let mut ephemeral_columns: Vec<IndexColumn> = table_reference
 .columns()
 .iter()
 .enumerate()
- .map(|(i, c)| IndexColumn {
- name: c.name.clone().unwrap(),
- order: SortOrder::Asc,
- pos_in_table: i,
- collation: c.collation_opt(),
- default: c.default.clone(),
- expr: None,
+ .map(|(i, c)| {
+ let Some(name) = c.name.clone() else {
+ return Err(crate::LimboError::ParseError(
+ "column must have a name".to_string(),
+ ));
+ };
+ Ok(IndexColumn {
+ name,
+ order: SortOrder::Asc,
+ pos_in_table: i,
+ collation: c.collation_opt(),
+ default: c.default.clone(),
+ expr: None,
+ })
 })
+ .collect::<Result<Vec<IndexColumn>>>()?
+ .into_iter()
 // only include columns that are used in the query
 .filter(|c| table_reference.column_is_used(c.pos_in_table))
 .collect();
@@ -1380,7 +1391,7 @@ fn ephemeral_index_build(
 index_method: None,
 };
- ephemeral_index
+ Ok(ephemeral_index)
 }
 /// Build a [SeekDef] for a given list of [Constraint]s
@@ -1435,7 +1446,9 @@ fn build_seek_def(
 mut key: Vec<RangeConstraintRef>,
 ) -> Result<SeekDef> {
 assert!(!key.is_empty());
- let last = key.last().unwrap();
+ let Some(last) = key.last() else {
+ crate::bail_parse_error!("key must not be empty");
+ };
 // if we searching for exact key - emit definition immediately with prefix as a full key
 if last.eq.is_some() {
@@ -1461,7 +1474,9 @@ fn build_seek_def(
 assert!(last.lower_bound.is_some() || last.upper_bound.is_some());
 // pop last key as we will do some form of range search
- let last = key.pop().unwrap();
+ let Some(last) = key.pop() else {
+ crate::bail_parse_error!("key must not be empty for pop");
+ };
 // after that all key components must be equality constraints
 debug_assert!(key.iter().all(|k| k.eq.is_some()));
diff --git a/core/translate/plan.rs b/core/translate/plan.rs
index 7fc9fdca4a..73c394563b 100644
--- a/core/translate/plan.rs
+++ b/core/translate/plan.rs
@@ -41,18 +41,26 @@ impl ResultSetColumn {
 }
 match &self.expr {
 ast::Expr::Column { table, column, .. } => {
- let joined_table_ref = tables.find_joined_table_by_internal_id(*table).unwrap();
+ let joined_table_ref = tables
+ .find_joined_table_by_internal_id(*table)
+ .expect("table internal ID is valid");
 if let Operation::IndexMethodQuery(module) = &joined_table_ref.op {
 if module.covered_columns.contains_key(column) {
 return None;
 }
 }
 let table_ref = &joined_table_ref.table;
- table_ref.get_column_at(*column).unwrap().name.as_deref()
+ table_ref
+ .get_column_at(*column)
+ .expect("column index is valid")
+ .name
+ .as_deref()
 }
 ast::Expr::RowId { table, .. } => {
 // If there is a rowid alias column, use its name
- let (_, table_ref) = tables.find_table_by_internal_id(*table).unwrap();
+ let (_, table_ref) = tables
+ .find_table_by_internal_id(*table)
+ .expect("table internal ID is valid");
 if let Table::BTree(table) = &table_ref {
 if let Some(rowid_alias_column) = table.get_rowid_alias_column() {
 if let Some(name) = &rowid_alias_column.1.name {
@@ -430,11 +438,18 @@ impl SelectPlan {
 {
 return false;
 }
- let table_ref = self.table_references.joined_tables().first().unwrap();
+ let table_ref = self
+ .table_references
+ .joined_tables()
+ .first()
+ .expect("checked table count above");
 if !matches!(table_ref.table, crate::schema::Table::BTree(..)) {
 return false;
 }
- let agg = self.aggregates.first().unwrap();
+ let agg = self
+ .aggregates
+ .first()
+ .expect("checked aggregates.len() == 1 above");
 if !matches!(agg.func, AggFunc::Count0) {
 return false;
 }
@@ -456,7 +471,11 @@ impl SelectPlan {
 over_clause: None,
 },
 };
- let result_col_expr = &self.result_columns.first().unwrap().expr;
+ let result_col_expr = &self
+ .result_columns
+ .first()
+ .expect("checked result_columns.len() == 1 above")
+ .expr;
 if *result_col_expr != count && *result_col_expr != count_star {
 return false;
 }
@@ -1313,7 +1332,11 @@ impl<'a> Iterator for SeekDefKeyIterator<'a, SeekKeyComponent<&'a ast::Expr>> {
 fn next(&mut self) -> Option<Self::Item> {
 let result = if self.pos < self.seek_def.prefix.len() {
 Some(SeekKeyComponent::Expr(
- &self.seek_def.prefix[self.pos].eq.as_ref().unwrap().1,
+ &self.seek_def.prefix[self.pos]
+ .eq
+ .as_ref()
+ .expect("prefix entries have eq")
+ .1,
 ))
 } else if self.pos == self.seek_def.prefix.len() {
 match &self.seek_key.last_component {
@@ -1333,7 +1356,13 @@ impl<'a> Iterator for SeekDefKeyIterator<'a, Affinity> {
 fn next(&mut self) -> Option<Self::Item> {
 let result = if self.pos < self.seek_def.prefix.len() {
- Some(self.seek_def.prefix[self.pos].eq.as_ref().unwrap().2)
+ Some(
+ self.seek_def.prefix[self.pos]
+ .eq
+ .as_ref()
+ .expect("prefix entries have eq")
+ .2,
+ )
 } else if self.pos == self.seek_def.prefix.len() {
 match &self.seek_key.last_component {
 SeekKeyComponent::Expr(..) => Some(self.seek_key.affinity),
@@ -1662,9 +1691,8 @@ impl NonFromClauseSubquery {
 let SubqueryState::Unevaluated { plan } = &self.state else {
 crate::bail_parse_error!("subquery has already been evaluated");
 };
- let used_outer_refs = plan
- .as_ref()
- .unwrap()
+ let plan_ref = plan.as_ref().expect("checked unevaluated state above");
+ let used_outer_refs = plan_ref
 .table_references
 .outer_query_refs()
 .iter()
@@ -1679,7 +1707,7 @@ impl NonFromClauseSubquery {
 };
 eval_at = eval_at.max(EvalAt::Loop(loop_idx));
 }
- for subquery in plan.as_ref().unwrap().non_from_clause_subqueries.iter() {
+ for subquery in plan_ref.non_from_clause_subqueries.iter() {
 let eval_at_inner = subquery.get_eval_at(join_order)?;
 eval_at = eval_at.max(eval_at_inner);
 }
@@ -1691,7 +1719,7 @@ impl NonFromClauseSubquery {
 pub fn consume_plan(&mut self, evaluated_at: EvalAt) -> Box<SelectPlan> {
 match &mut self.state {
 SubqueryState::Unevaluated { plan } => {
- let plan = plan.take().unwrap();
+ let plan = plan.take().expect("plan exists in unevaluated state");
 self.state = SubqueryState::Evaluated { evaluated_at };
 plan
 }
@@ -1729,7 +1757,7 @@ mod tests {
 let mut rng = ChaCha8Rng::seed_from_u64(
 std::time::SystemTime::now()
 .duration_since(std::time::UNIX_EPOCH)
- .unwrap()
+ .expect("system time is after UNIX epoch")
 .as_secs(),
 );
@@ -1762,7 +1790,7 @@ mod tests {
 let mut rng = ChaCha8Rng::seed_from_u64(
 std::time::SystemTime::now()
 .duration_since(std::time::UNIX_EPOCH)
- .unwrap()
+ .expect("system time is after UNIX epoch")
 .as_secs(),
 );
diff --git a/core/translate/planner.rs b/core/translate/planner.rs
index b8f31db2eb..f776d8c29f 100644
--- a/core/translate/planner.rs
+++ b/core/translate/planner.rs
@@ -726,7 +726,8 @@ pub fn parse_from(
 }
 }
- let from_owned = std::mem::take(&mut from).unwrap();
+ let from_owned =
+ std::mem::take(&mut from).expect("from clause must be Some for SELECT with FROM");
 let select_owned = from_owned.select;
 let joins_owned = from_owned.joins;
 parse_from_clause_table(
@@ -939,7 +940,7 @@ pub fn table_mask_from_expr(
 };
 let used_outer_query_refs = plan
 .as_ref()
- .unwrap()
+ .expect("unevaluated subquery must have a plan")
 .table_references
 .outer_query_refs()
 .iter()
@@ -992,7 +993,7 @@ pub fn determine_where_to_eval_expr(
 SubqueryState::Unevaluated { plan } => {
 let used_outer_refs = plan
 .as_ref()
- .unwrap()
+ .expect("unevaluated subquery must have a plan")
 .table_references
 .outer_query_refs()
 .iter()
@@ -1066,7 +1067,10 @@ fn parse_join(
 // this is called once for each join, so we only need to check the rightmost table
 // against all previous tables for duplicates
- let rightmost_table = table_references.joined_tables().last().unwrap();
+ let rightmost_table = table_references
+ .joined_tables()
+ .last()
+ .expect("joined_tables must not be empty when checking for duplicates");
 let has_duplicate = table_references
 .joined_tables()
 .iter()
@@ -1185,8 +1189,10 @@ fn parse_join(
 distinct_name.as_str()
 );
 }
- let (left_table_idx, left_table_id, left_col_idx, left_col) = left_col.unwrap();
- let (right_col_idx, right_col) = right_col.unwrap();
+ let (left_table_idx, left_table_id, left_col_idx, left_col) =
+ left_col.expect("left_col must be Some when not bailing");
+ let (right_col_idx, right_col) =
+ right_col.expect("right_col must be Some when not bailing");
 let expr = Expr::Binary(
 Box::new(Expr::Column {
 database: None,
@@ -1206,12 +1212,12 @@ fn parse_join(
 let left_table: &mut JoinedTable = table_references
 .joined_tables_mut()
 .get_mut(left_table_idx)
- .unwrap();
.expect("left_table_idx must be valid"); left_table.mark_column_used(left_col_idx); let right_table: &mut JoinedTable = table_references .joined_tables_mut() .get_mut(cur_table_idx) - .unwrap(); + .expect("cur_table_idx must be valid"); right_table.mark_column_used(right_col_idx); out_where_clause.push(WhereTerm { expr, @@ -1233,7 +1239,7 @@ fn parse_join( let rightmost_table = table_references .joined_tables_mut() .get_mut(last_idx) - .unwrap(); + .expect("last_idx must be valid since joined_tables().len() >= 2"); rightmost_table.join_info = Some(JoinInfo { outer, using }); Ok(()) diff --git a/core/translate/schema.rs b/core/translate/schema.rs index 38a5c2c619..5e593c8070 100644 --- a/core/translate/schema.rs +++ b/core/translate/schema.rs @@ -144,7 +144,9 @@ pub fn translate_create_table( } } - let schema_master_table = resolver.schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let Some(schema_master_table) = resolver.schema.get_btree_table(SQLITE_TABLEID) else { + crate::bail_parse_error!("sqlite_schema table not found"); + }; let sqlite_schema_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(schema_master_table.clone())); program.emit_insn(Insn::OpenWrite { @@ -233,7 +235,9 @@ pub fn translate_create_table( } } - let table = resolver.schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let Some(table) = resolver.schema.get_btree_table(SQLITE_TABLEID) else { + crate::bail_parse_error!("sqlite_schema table not found"); + }; let sqlite_schema_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(table.clone())); program.emit_insn(Insn::OpenWrite { cursor_id: sqlite_schema_cursor_id, @@ -424,7 +428,9 @@ fn collect_autoindexes( // include UNIQUE singles, include PK single only if not rowid alias for us in table.unique_sets.iter().filter(|us| us.columns.len() == 1) { - let (col_name, _sort) = us.columns.first().unwrap(); + let Some((col_name, _sort)) = us.columns.first() else { + crate::bail_parse_error!("unique set must have at least one column"); + }; let Some((_pos, col)) = table.get_column(col_name) else { bail_parse_error!("Column {col_name} not found in table {}", table.name); }; @@ -570,7 +576,9 @@ pub fn translate_create_virtual_table( table_name: table_name_reg, args_reg, }); - let table = resolver.schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let Some(table) = resolver.schema.get_btree_table(SQLITE_TABLEID) else { + crate::bail_parse_error!("sqlite_schema table not found"); + }; let sqlite_schema_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(table.clone())); program.emit_insn(Insn::OpenWrite { cursor_id: sqlite_schema_cursor_id, @@ -653,7 +661,9 @@ pub fn translate_drop_table( bail_parse_error!("table {} may not be dropped", tbl_name.name.as_str()); } - let table = table.unwrap(); // safe since we just checked for None + let Some(table) = table else { + crate::bail_parse_error!("table not found"); + }; // Check if this is a materialized view - if so, refuse to drop it with DROP TABLE if resolver.schema.is_materialized_view(tbl_name.name.as_str()) { @@ -683,7 +693,9 @@ pub fn translate_drop_table( program.mark_last_insn_constant(); let row_id_reg = program.alloc_register(); // r5 - let schema_table = resolver.schema.get_btree_table(SQLITE_TABLEID).unwrap(); + let Some(schema_table) = resolver.schema.get_btree_table(SQLITE_TABLEID) else { + crate::bail_parse_error!("sqlite_schema table not found"); + }; let sqlite_schema_cursor_id_0 = program.alloc_cursor_id( // cursor 0 CursorType::BTreeTable(schema_table.clone()), diff --git 
a/core/translate/select.rs b/core/translate/select.rs index 7c024cb616..491fde0bdb 100644 --- a/core/translate/select.rs +++ b/core/translate/select.rs @@ -337,10 +337,10 @@ fn prepare_one_select_plan( .iter_mut() .find(|t| t.identifier == name_normalized); - if referenced_table.is_none() { - crate::bail_parse_error!("no such table: {}", name.as_str()); - } - let table = referenced_table.unwrap(); + let table = match referenced_table { + Some(t) => t, + None => crate::bail_parse_error!("no such table: {}", name.as_str()), + }; let num_columns = table.columns().len(); for idx in 0..num_columns { let column = &table.columns()[idx]; @@ -768,11 +768,11 @@ pub fn emit_simple_count( _t_ctx: &mut TranslateCtx, plan: &SelectPlan, ) -> Result<()> { - let cursors = plan - .joined_tables() - .first() - .unwrap() - .resolve_cursors(program, OperationMode::SELECT)?; + let first_table = match plan.joined_tables().first() { + Some(t) => t, + None => crate::bail_parse_error!("no joined tables in plan"), + }; + let cursors = first_table.resolve_cursors(program, OperationMode::SELECT)?; let cursor_id = { match cursors { diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 8101660b4d..ee73b3322c 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -24,7 +24,10 @@ use crate::vdbe::affinity::{apply_numeric_affinity, try_for_float, Affinity, Par use crate::vdbe::insn::InsertFlags; use crate::vdbe::value::ComparisonOp; use crate::vdbe::{registers_to_ref_values, EndStatement, StepResult, TxnCleanup}; -use crate::vector::{vector32_sparse, vector_concat, vector_distance_jaccard, vector_slice}; +use crate::vector::{ + vector32, vector32_sparse, vector64, vector_concat, vector_distance_cos, + vector_distance_jaccard, vector_distance_l2, vector_extract, vector_slice, +}; use crate::{ error::{ LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, @@ -66,7 +69,6 @@ use crate::{ builder::CursorType, insn::{IdxInsertFlags, Insn}, }, - vector::{vector32, vector64, vector_distance_cos, vector_distance_l2, vector_extract}, }; use crate::{info, turso_assert, OpenFlags, Row, TransactionState, ValueRef}; @@ -5526,47 +5528,46 @@ pub fn op_function( } }, crate::function::Func::Vector(vector_func) => { - let values = - registers_to_ref_values(&state.registers[*start_reg..*start_reg + arg_count]); + let args = &state.registers[*start_reg..*start_reg + arg_count]; match vector_func { VectorFunc::Vector => { - let result = vector32(values)?; + let result = vector32(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::Vector32 => { - let result = vector32(values)?; + let result = vector32(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::Vector32Sparse => { - let result = vector32_sparse(values)?; + let result = vector32_sparse(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::Vector64 => { - let result = vector64(values)?; + let result = vector64(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorExtract => { - let result = vector_extract(values)?; + let result = vector_extract(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorDistanceCos => { - let result = vector_distance_cos(values)?; + let result = vector_distance_cos(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorDistanceL2 => { - let result = vector_distance_l2(values)?; + let result = vector_distance_l2(args)?; state.registers[*dest] = Register::Value(result); } 
VectorFunc::VectorDistanceJaccard => {
- let result = vector_distance_jaccard(values)?;
+ let result = vector_distance_jaccard(args)?;
 state.registers[*dest] = Register::Value(result);
 }
 VectorFunc::VectorConcat => {
- let result = vector_concat(values)?;
+ let result = vector_concat(args)?;
 state.registers[*dest] = Register::Value(result);
 }
 VectorFunc::VectorSlice => {
- let result = vector_slice(values)?;
+ let result = vector_slice(args)?;
 state.registers[*dest] = Register::Value(result)
 }
 }
diff --git a/core/vector/mod.rs b/core/vector/mod.rs
index 31038564b1..17f2be1316 100644
--- a/core/vector/mod.rs
+++ b/core/vector/mod.rs
@@ -1,6 +1,7 @@
 use crate::types::AsValueRef;
 use crate::types::Value;
 use crate::types::ValueType;
+use crate::vdbe::Register;
 use crate::LimboError;
 use crate::Result;
 use crate::ValueRef;
@@ -33,75 +34,50 @@ pub fn parse_vector<'a>(
 }
 }
-pub fn vector32<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector32(args: &[Register]) -> Result<Value> {
 if args.len() != 1 {
 return Err(LimboError::ConversionError(
 "vector32 requires exactly one argument".to_string(),
 ));
 }
- let value = args.next().unwrap();
- let vector = parse_vector(&value, Some(VectorType::Float32Dense))?;
+ let value = args[0].get_value();
+ let vector = parse_vector(value, Some(VectorType::Float32Dense))?;
 let vector = operations::convert::vector_convert(vector, VectorType::Float32Dense)?;
 Ok(operations::serialize::vector_serialize(vector))
 }
-pub fn vector32_sparse<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector32_sparse(args: &[Register]) -> Result<Value> {
 if args.len() != 1 {
 return Err(LimboError::ConversionError(
 "vector32_sparse requires exactly one argument".to_string(),
 ));
 }
- let value = args.next().unwrap();
- let vector = parse_vector(&value, Some(VectorType::Float32Sparse))?;
+ let value = args[0].get_value();
+ let vector = parse_vector(value, Some(VectorType::Float32Sparse))?;
 let vector = operations::convert::vector_convert(vector, VectorType::Float32Sparse)?;
 Ok(operations::serialize::vector_serialize(vector))
 }
-pub fn vector64<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector64(args: &[Register]) -> Result<Value> {
 if args.len() != 1 {
 return Err(LimboError::ConversionError(
 "vector64 requires exactly one argument".to_string(),
 ));
 }
- let value = args.next().unwrap();
- let vector = parse_vector(&value, Some(VectorType::Float64Dense))?;
+ let value = args[0].get_value();
+ let vector = parse_vector(value, Some(VectorType::Float64Dense))?;
 let vector = operations::convert::vector_convert(vector, VectorType::Float64Dense)?;
 Ok(operations::serialize::vector_serialize(vector))
 }
-pub fn vector_extract<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector_extract(args: &[Register]) -> Result<Value> {
 if args.len() != 1 {
 return Err(LimboError::ConversionError(
 "vector_extract requires exactly one argument".to_string(),
 ));
 }
- let value = args.next().unwrap();
- let value = value.as_value_ref();
+ let value = args[0].get_value().as_value_ref();
 let blob = match value {
 ValueRef::Blob(b) => b,
 _ => {
@@ -119,110 +95,77 @@ where
 Ok(Value::build_text(operations::text::vector_to_text(&vector)))
 }
-pub fn vector_distance_cos<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector_distance_cos(args: &[Register]) -> Result<Value> {
 if args.len() != 2 {
 return Err(LimboError::ConversionError(
 "vector_distance_cos requires exactly two arguments".to_string(),
 ));
 }
- let value_0 = args.next().unwrap();
- let value_1 = args.next().unwrap();
- let x = parse_vector(&value_0, None)?;
- let y = parse_vector(&value_1, None)?;
+ let value_0 = args[0].get_value();
+ let value_1 = args[1].get_value();
+ let x = parse_vector(value_0, None)?;
+ let y = parse_vector(value_1, None)?;
 let dist = operations::distance_cos::vector_distance_cos(&x, &y)?;
 Ok(Value::Float(dist))
 }
-pub fn vector_distance_l2<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector_distance_l2(args: &[Register]) -> Result<Value> {
 if args.len() != 2 {
 return Err(LimboError::ConversionError(
 "distance_l2 requires exactly two arguments".to_string(),
 ));
 }
- let value_0 = args.next().unwrap();
- let value_1 = args.next().unwrap();
- let x = parse_vector(&value_0, None)?;
- let y = parse_vector(&value_1, None)?;
+ let value_0 = args[0].get_value();
+ let value_1 = args[1].get_value();
+ let x = parse_vector(value_0, None)?;
+ let y = parse_vector(value_1, None)?;
 let dist = operations::distance_l2::vector_distance_l2(&x, &y)?;
 Ok(Value::Float(dist))
 }
-pub fn vector_distance_jaccard<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector_distance_jaccard(args: &[Register]) -> Result<Value> {
 if args.len() != 2 {
 return Err(LimboError::ConversionError(
 "distance_jaccard requires exactly two arguments".to_string(),
 ));
 }
- let value_0 = args.next().unwrap();
- let value_1 = args.next().unwrap();
- let x = parse_vector(&value_0, None)?;
- let y = parse_vector(&value_1, None)?;
+ let value_0 = args[0].get_value();
+ let value_1 = args[1].get_value();
+ let x = parse_vector(value_0, None)?;
+ let y = parse_vector(value_1, None)?;
 let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?;
 Ok(Value::Float(dist))
 }
-pub fn vector_concat<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector_concat(args: &[Register]) -> Result<Value> {
 if args.len() != 2 {
 return Err(LimboError::InvalidArgument(
 "concat requires exactly two arguments".into(),
 ));
 }
- let value_0 = args.next().unwrap();
- let value_1 = args.next().unwrap();
- let x = parse_vector(&value_0, None)?;
- let y = parse_vector(&value_1, None)?;
+ let value_0 = args[0].get_value();
+ let value_1 = args[1].get_value();
+ let x = parse_vector(value_0, None)?;
+ let y = parse_vector(value_1, None)?;
 let vector = operations::concat::vector_concat(&x, &y)?;
 Ok(operations::serialize::vector_serialize(vector))
 }
-pub fn vector_slice<V, E, I>(args: I) -> Result<Value>
-where
- V: AsValueRef,
- E: ExactSizeIterator<Item = V>,
- I: IntoIterator<Item = V, IntoIter = E>,
-{
- let mut args = args.into_iter();
+pub fn vector_slice(args: &[Register]) -> Result<Value> {
 if args.len() != 3 {
 return Err(LimboError::InvalidArgument(
 "vector_slice requires exactly three arguments".into(),
 ));
 }
- let value_0 = args.next().unwrap();
- let value_1 = args.next().unwrap();
- let value_1 = value_1.as_value_ref();
-
- let value_2 = args.next().unwrap();
- let value_2 = value_2.as_value_ref();
+ let value_0 = args[0].get_value();
+ let value_1 = args[1].get_value().as_value_ref();
+ let value_2 = args[2].get_value().as_value_ref();
- let vector = parse_vector(&value_0, None)?;
+ let vector = parse_vector(value_0, None)?;
 let start_index = value_1
 .as_int()
diff --git a/core/vector/vector_types.rs b/core/vector/vector_types.rs
index 8b70fdbda2..8186119a94 100644
--- a/core/vector/vector_types.rs
+++ b/core/vector/vector_types.rs
@@ -128,7 +128,9 @@ impl<'a> Vector<'a> {
 ) -> Result<Self> {
 let owned_slice = owned.as_deref();
 let refer_slice = refer.as_ref().map(|&x| x);
- let data = owned_slice.unwrap_or_else(|| refer_slice.unwrap());
+ let data = owned_slice.or(refer_slice).ok_or_else(|| {
+ LimboError::InternalError("Vector must have either owned or refer data".to_string())
+ })?;
 match vector_type {
 VectorType::Float32Dense => {
 if data.len() % 4 != 0 {
@@ -167,7 +169,12 @@ impl<'a> Vector<'a> {
 }
 let original_len = data.len();
 let dims_bytes = &data[original_len - 4..];
- let dims = u32::from_le_bytes(dims_bytes.try_into().unwrap()) as usize;
+ let dims = u32::from_le_bytes([
+ dims_bytes[0],
+ dims_bytes[1],
+ dims_bytes[2],
+ dims_bytes[3],
+ ]) as usize;
 let owned = owned.map(|mut x| {
 x.truncate(original_len - 4);
 x
@@ -187,17 +194,25 @@ impl<'a> Vector<'a> {
 pub fn bin_len(&self) -> usize {
 let owned = self.owned.as_ref().map(|x| x.len());
 let refer = self.refer.as_ref().map(|x| x.len());
- owned.unwrap_or_else(|| refer.unwrap())
+ owned
+ .or(refer)
+ .expect("Vector invariant: exactly one of owned or refer must be Some")
 }
 pub fn bin_data(&'a self) -> &'a [u8] {
 let owned = self.owned.as_deref();
 let refer = self.refer.as_ref().map(|&x| x);
- owned.unwrap_or_else(|| refer.unwrap())
+ owned
+ .or(refer)
+ .expect("Vector invariant: exactly one of owned or refer must be Some")
 }
 pub fn bin_eject(self) -> Vec<u8> {
- self.owned.unwrap_or_else(|| self.refer.unwrap().to_vec())
+ self.owned.unwrap_or_else(|| {
+ self.refer
+ .expect("Vector invariant: exactly one of owned or refer must be Some")
+ .to_vec()
+ })
 }
 /// # Safety
diff --git a/parser/src/lib.rs b/parser/src/lib.rs
index 14f16356c9..56b5caa2b9 100644
--- a/parser/src/lib.rs
+++ b/parser/src/lib.rs
@@ -1,3 +1,5 @@
+#![allow(clippy::unwrap_used)]
+
 pub mod ast;
 pub mod error;
 pub mod lexer;