diff --git a/core/lib.rs b/core/lib.rs
index f0100bbd66..ba39ff5bc4 100644
--- a/core/lib.rs
+++ b/core/lib.rs
@@ -41,9 +41,7 @@ pub mod numeric;
 mod numeric;
 
 use crate::index_method::IndexMethod;
-use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES;
 use crate::storage::encryption::AtomicCipherMode;
-use crate::storage::pager::{AutoVacuumMode, HeaderRef};
 use crate::translate::display::PlanContext;
 use crate::translate::pragma::TURSO_CDC_DEFAULT_TABLE_NAME;
 #[cfg(all(feature = "fs", feature = "conn_raw_api"))]
@@ -90,7 +88,7 @@ pub use storage::database::IOContext;
 pub use storage::encryption::{CipherMode, EncryptionContext, EncryptionKey};
 use storage::page_cache::PageCache;
 use storage::pager::{AtomicDbState, DbState};
-use storage::sqlite3_ondisk::PageSize;
+use storage::sqlite3_ondisk::{CacheSize, PageSize};
 pub use storage::{
     buffer_pool::BufferPool,
     database::DatabaseStorage,
@@ -244,6 +242,7 @@ pub struct Database {
     open_flags: Cell,
     builtin_syms: RwLock,
     opts: DatabaseOpts,
+    pending_encryption_opts: Mutex<Option<EncryptionOpts>>,
     n_connections: AtomicUsize,
 }
 
@@ -498,6 +497,7 @@ impl Database {
             db_state: Arc::new(AtomicDbState::new(db_state)),
             init_lock: Arc::new(Mutex::new(())),
             opts,
+            pending_encryption_opts: Mutex::new(encryption_opts),
             buffer_pool: BufferPool::begin_init(&io, arena_size),
             n_connections: AtomicUsize::new(0),
         });
@@ -505,49 +505,6 @@ impl Database {
         db.register_global_builtin_extensions()
             .expect("unable to register global extensions");
 
-        // Check: https://github.com/tursodatabase/turso/pull/1761#discussion_r2154013123
-        if db_state.is_initialized() {
-            // parse schema
-            let conn = db.connect()?;
-
-            let syms = conn.syms.read();
-            let pager = conn.pager.load().clone();
-
-            if let Some(encryption_opts) = encryption_opts {
-                conn.pragma_update("cipher", format!("'{}'", encryption_opts.cipher))?;
-                conn.pragma_update("hexkey", format!("'{}'", encryption_opts.hexkey))?;
-                // Clear page cache so the header page can be reread from disk and decrypted using the encryption context.
-                pager.clear_page_cache(false);
-            }
-            db.with_schema_mut(|schema| {
-                let header_schema_cookie = pager
-                    .io
-                    .block(|| pager.with_header(|header| header.schema_cookie.get()))?;
-                schema.schema_version = header_schema_cookie;
-                let result = schema
-                    .make_from_btree(None, pager.clone(), &syms)
-                    .inspect_err(|_| pager.end_read_tx());
-                match result {
-                    Err(LimboError::ExtensionError(e)) => {
-                        // this means that a vtab exists and we no longer have the module loaded. we print
-                        // a warning to the user to load the module
-                        eprintln!("Warning: {e}");
-                    }
-                    Err(e) => return Err(e),
-                    _ => {}
-                }
-
-                if db.mvcc_enabled() && !schema.indexes.is_empty() {
-                    return Err(LimboError::ParseError(
-                        "Database contains indexes which are not supported when MVCC is enabled."
-                            .to_string(),
-                    ));
-                }
-
-                Ok(())
-            })?;
-        }
-
         if opts.enable_mvcc {
             let mv_store = db.mv_store.as_ref().unwrap();
             let mvcc_bootstrap_conn = db.connect_mvcc_bootstrap()?;
@@ -574,50 +531,11 @@ impl Database {
         let pager = self.init_pager(None)?;
         pager.enable_encryption(self.opts.enable_encryption);
         let pager = Arc::new(pager);
+        pager.allow_autovacuum(self.opts.enable_autovacuum);
 
-        if self.db_state.get().is_initialized() {
-            let header_ref = pager.io.block(|| HeaderRef::from_pager(&pager))?;
-
-            let header = header_ref.borrow();
-
-            let mode = if header.vacuum_mode_largest_root_page.get() > 0 {
-                if header.incremental_vacuum_enabled.get() > 0 {
-                    AutoVacuumMode::Incremental
-                } else {
-                    AutoVacuumMode::Full
-                }
-            } else {
-                AutoVacuumMode::None
-            };
-
-            // Force autovacuum to None if the experimental flag is not enabled
-            let final_mode = if !self.opts.enable_autovacuum {
-                if mode != AutoVacuumMode::None {
-                    tracing::warn!(
-                        "Database has autovacuum enabled but --experimental-autovacuum flag is not set. Forcing autovacuum to None."
-                    );
-                }
-                AutoVacuumMode::None
-            } else {
-                mode
-            };
-
-            pager.set_auto_vacuum_mode(final_mode);
-
-            tracing::debug!(
-                "Opened existing database. Detected auto_vacuum_mode from header: {:?}, final mode: {:?}",
-                mode,
-                final_mode
-            );
-        }
-
-        let page_size = pager.get_page_size_unchecked();
+        let initial_page_size_raw = pager.get_page_size().map(|ps| ps.get_raw()).unwrap_or(0);
 
-        let default_cache_size = pager
-            .io
-            .block(|| pager.with_header(|header| header.default_page_cache_size))
-            .unwrap_or_default()
-            .get();
+        let default_cache_size = CacheSize::default().get();
 
         let conn = Arc::new(Connection {
             db: self.clone(),
             pager: ArcSwap::new(pager),
@@ -631,7 +549,7 @@ impl Database {
             syms: RwLock::new(SymbolTable::new()),
             _shared_cache: false,
             cache_size: AtomicI32::new(default_cache_size),
-            page_size: AtomicU16::new(page_size.get_raw()),
+            page_size: AtomicU16::new(initial_page_size_raw),
             wal_auto_checkpoint_disabled: AtomicBool::new(false),
             capture_data_changes: RwLock::new(CaptureDataChangesMode::Off),
             closed: AtomicBool::new(false),
@@ -650,6 +568,9 @@ impl Database {
             fk_pragma: AtomicBool::new(false),
             fk_deferred_violations: AtomicIsize::new(0),
         });
+        if let Some(encryption_opts) = self.pending_encryption_opts.lock().unwrap().take() {
+            conn.apply_encryption_opts(encryption_opts)?;
+        }
         self.n_connections
             .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
         let builtin_syms = self.builtin_syms.read();
@@ -662,174 +583,65 @@ impl Database {
         self.open_flags.get().contains(OpenFlags::ReadOnly)
     }
 
-    /// If we do not have a physical WAL file, but we know the database file is initialized on disk,
-    /// we need to read the page_size from the database header.
-    fn read_page_size_from_db_header(&self) -> Result {
-        turso_assert!(
-            self.db_state.get().is_initialized(),
-            "read_page_size_from_db_header called on uninitialized database"
-        );
-        turso_assert!(
-            PageSize::MIN % 512 == 0,
-            "header read must be a multiple of 512 for O_DIRECT"
-        );
-        let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
-        let c = Completion::new_read(buf.clone(), move |_res| {});
-        let c = self.db_file.read_header(c)?;
-        self.io.wait_for_completion(c)?;
-        let page_size = u16::from_be_bytes(buf.as_slice()[16..18].try_into().unwrap());
-        let page_size = PageSize::new_from_header_u16(page_size)?;
-        Ok(page_size)
-    }
-
-    fn read_reserved_space_bytes_from_db_header(&self) -> Result {
-        turso_assert!(
-            self.db_state.get().is_initialized(),
-            "read_reserved_space_bytes_from_db_header called on uninitialized database"
-        );
-        turso_assert!(
-            PageSize::MIN % 512 == 0,
-            "header read must be a multiple of 512 for O_DIRECT"
-        );
-        let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
-        let c = Completion::new_read(buf.clone(), move |_res| {});
-        let c = self.db_file.read_header(c)?;
-        self.io.wait_for_completion(c)?;
-        let reserved_bytes = u8::from_be_bytes(buf.as_slice()[20..21].try_into().unwrap());
-        Ok(reserved_bytes)
-    }
-
-    /// Read the page size in order of preference:
-    /// 1. From the WAL header if it exists and is initialized
-    /// 2. From the database header if the database is initialized
-    ///
-    /// Otherwise, fall back to, in order of preference:
-    /// 1. From the requested page size if it is provided
-    /// 2. PageSize::default(), i.e. 4096
-    fn determine_actual_page_size(
-        &self,
-        shared_wal: &WalFileShared,
-        requested_page_size: Option,
-    ) -> Result {
-        if shared_wal.enabled.load(Ordering::SeqCst) {
-            let size_in_wal = shared_wal.page_size();
-            if size_in_wal != 0 {
-                let Some(page_size) = PageSize::new(size_in_wal) else {
-                    bail_corrupt_error!("invalid page size in WAL: {size_in_wal}");
-                };
-                return Ok(page_size);
-            }
-        }
-        if self.db_state.get().is_initialized() {
-            Ok(self.read_page_size_from_db_header()?)
-        } else {
-            let Some(size) = requested_page_size else {
-                return Ok(PageSize::default());
-            };
-            let Some(page_size) = PageSize::new(size as u32) else {
-                bail_corrupt_error!("invalid requested page size: {size}");
-            };
-            Ok(page_size)
-        }
-    }
-
-    /// if the database is initialized i.e. it exists on disk, return the reserved space bytes from
-    /// the header or None
-    fn maybe_get_reserved_space_bytes(&self) -> Result> {
-        if self.db_state.get().is_initialized() {
-            Ok(Some(self.read_reserved_space_bytes_from_db_header()?))
-        } else {
-            Ok(None)
-        }
-    }
-
     fn init_pager(&self, requested_page_size: Option) -> Result {
-        let reserved_bytes = self.maybe_get_reserved_space_bytes()?;
-        let disable_checksums = if let Some(reserved_bytes) = reserved_bytes {
-            // if the required reserved bytes for checksums is not present, disable checksums
-            reserved_bytes != CHECKSUM_REQUIRED_RESERVED_BYTES
-        } else {
-            false
+        let buffer_pool = self.buffer_pool.clone();
+        let db_state = self.db_state.clone();
+
+        let wal_enabled = {
+            let shared = self.shared_wal.read();
+            shared.enabled.load(Ordering::SeqCst)
         };
-        // Check if WAL is enabled
-        let shared_wal = self.shared_wal.read();
-        if shared_wal.enabled.load(Ordering::SeqCst) {
-            let page_size = self.determine_actual_page_size(&shared_wal, requested_page_size)?;
-            drop(shared_wal);
-
-            let buffer_pool = self.buffer_pool.clone();
-            if self.db_state.get().is_initialized() {
-                buffer_pool.finalize_with_page_size(page_size.get() as usize)?;
-            }
-            let db_state = self.db_state.clone();
-            let wal = Rc::new(RefCell::new(WalFile::new(
+        let initial_wal: Option<Rc<RefCell<WalFile>>> = if wal_enabled {
+            Some(Rc::new(RefCell::new(WalFile::new(
                 self.io.clone(),
                 self.shared_wal.clone(),
                 buffer_pool.clone(),
-            )));
-            let pager = Pager::new(
-                self.db_file.clone(),
-                Some(wal),
-                self.io.clone(),
-                Arc::new(RwLock::new(PageCache::default())),
-                buffer_pool.clone(),
-                db_state,
-                self.init_lock.clone(),
-            )?;
-            pager.set_page_size(page_size);
-            if let Some(reserved_bytes) = reserved_bytes {
-                pager.set_reserved_space_bytes(reserved_bytes);
-            }
-            if disable_checksums {
-                pager.reset_checksum_context();
-            }
-            return Ok(pager);
-        }
-        let page_size = self.determine_actual_page_size(&shared_wal, requested_page_size)?;
-        drop(shared_wal);
-
-        let buffer_pool = self.buffer_pool.clone();
+            ))))
+        } else {
+            None
+        };
 
-        if self.db_state.get().is_initialized() {
-            buffer_pool.finalize_with_page_size(page_size.get() as usize)?;
-        }
+        let init_lock = if wal_enabled {
+            self.init_lock.clone()
+        } else {
+            Arc::new(Mutex::new(()))
+        };
 
-        // No existing WAL; create one.
-        let db_state = self.db_state.clone();
         let mut pager = Pager::new(
             self.db_file.clone(),
-            None,
+            initial_wal.clone(),
             self.io.clone(),
             Arc::new(RwLock::new(PageCache::default())),
             buffer_pool.clone(),
             db_state,
-            Arc::new(Mutex::new(())),
+            init_lock,
         )?;
-        pager.set_page_size(page_size);
-        if let Some(reserved_bytes) = reserved_bytes {
-            pager.set_reserved_space_bytes(reserved_bytes);
-        }
-        if disable_checksums {
-            pager.reset_checksum_context();
+        if !self.db_state.get().is_initialized() {
+            let page_size = requested_page_size
+                .and_then(|size| PageSize::new(size as u32))
+                .unwrap_or_else(PageSize::default);
+            pager.set_page_size(page_size);
+            buffer_pool.finalize_with_page_size(page_size.get() as usize)?;
         }
 
-        let file = self
-            .io
-            .open_file(&self.wal_path, OpenFlags::Create, false)?;
-        // Enable WAL in the existing shared instance
-        {
-            let mut shared_wal = self.shared_wal.write();
-            shared_wal.create(file)?;
-        }
+        if initial_wal.is_none() {
+            let file = self
+                .io
+                .open_file(&self.wal_path, OpenFlags::Create, false)?;
+            {
+                let mut shared_wal = self.shared_wal.write();
+                shared_wal.create(file)?;
+            }
 
-        let wal = Rc::new(RefCell::new(WalFile::new(
-            self.io.clone(),
-            self.shared_wal.clone(),
-            buffer_pool,
-        )));
-        pager.set_wal(wal);
+            let wal: Rc<RefCell<WalFile>> = Rc::new(RefCell::new(WalFile::new(
+                self.io.clone(),
+                self.shared_wal.clone(),
+                buffer_pool,
+            )));
+            pager.set_wal(wal);
+        }
 
         Ok(pager)
     }
@@ -1196,6 +1008,13 @@ impl Drop for Connection {
 }
 
 impl Connection {
+    fn apply_encryption_opts(self: &Arc<Self>, opts: EncryptionOpts) -> Result<()> {
+        self.pragma_update("cipher", format!("'{}'", opts.cipher))?;
+        self.pragma_update("hexkey", format!("'{}'", opts.hexkey))?;
+        self.pager.load().clear_page_cache(false);
+        Ok(())
+    }
+
     /// check if connection executes nested program (so it must not do any "finalization" work as parent program will handle it)
     pub fn is_nested_stmt(&self) -> bool {
         self.nestedness.load(Ordering::SeqCst) > 0
@@ -1208,6 +1027,48 @@ impl Connection {
     pub fn end_nested(&self) {
         self.nestedness.fetch_add(-1, Ordering::SeqCst);
     }
+
+    fn ensure_schema_loaded(self: &Arc<Self>) -> Result<()> {
+        if !self.db.db_state.get().is_initialized() {
+            return Ok(());
+        }
+        let needs_load = {
+            let schema = self.db.schema.lock().unwrap();
+            schema.schema_version == 0 && schema.tables.is_empty()
+        };
+        if !needs_load {
+            return Ok(());
+        }
+        let pager = self.pager.load().clone();
+        let syms = self.syms.read();
+        let load_result = self.db.with_schema_mut(|schema| {
+            if schema.schema_version != 0 || !schema.tables.is_empty() {
+                return Ok(());
+            }
+            let header_schema_cookie = pager
+                .io
+                .block(|| pager.with_header(|header| header.schema_cookie.get()))?;
+            schema.schema_version = header_schema_cookie;
+            match schema.make_from_btree(None, pager.clone(), &*syms) {
+                Err(LimboError::ExtensionError(e)) => {
+                    eprintln!("Warning: {e}");
+                }
+                Err(err) => return Err(err),
+                Ok(()) => {}
+            }
+            if self.mvcc_enabled() && !schema.indexes.is_empty() {
+                return Err(LimboError::ParseError(
+                    "Database contains indexes which are not supported when MVCC is enabled."
+                        .to_string(),
+                ));
+            }
+            Ok(())
+        });
+        drop(syms);
+        load_result?;
+        Ok(())
+    }
+
     pub fn prepare(self: &Arc, sql: impl AsRef) -> Result {
         if self.is_mvcc_bootstrap_connection() {
             // Never use MV store for bootstrapping - we read state directly from sqlite_schema in the DB file.
@@ -1241,6 +1102,7 @@ impl Connection {
         let input = str::from_utf8(&sql.as_bytes()[..byte_offset_end])
             .unwrap()
             .trim();
+        self.ensure_schema_loaded()?;
         self.maybe_update_schema();
         let pager = self.pager.load().clone();
         let mode = QueryMode::new(&cmd);
@@ -1390,6 +1252,7 @@ impl Connection {
                 "The supplied SQL string contains no statements".to_string(),
             ));
         }
+        self.ensure_schema_loaded()?;
         self.maybe_update_schema();
         let sql = sql.as_ref();
         tracing::trace!("Preparing and executing batch: {}", sql);
@@ -1424,6 +1287,7 @@ impl Connection {
             return Err(LimboError::InternalError("Connection closed".to_string()));
         }
         let sql = sql.as_ref();
+        self.ensure_schema_loaded()?;
         self.maybe_update_schema();
         tracing::trace!("Querying: {}", sql);
         let mut parser = Parser::new(sql.as_bytes());
@@ -1451,6 +1315,8 @@ impl Connection {
         let pager = self.pager.load().clone();
         let mode = QueryMode::new(&cmd);
         let (Cmd::Stmt(stmt) | Cmd::Explain(stmt) | Cmd::ExplainQueryPlan(stmt)) = cmd;
+        self.ensure_schema_loaded()?;
+        self.maybe_update_schema();
         let program = translate::translate(
             self.schema.read().deref(),
             stmt,
@@ -1476,6 +1342,7 @@ impl Connection {
             return Err(LimboError::InternalError("Connection closed".to_string()));
         }
         let sql = sql.as_ref();
+        self.ensure_schema_loaded()?;
         self.maybe_update_schema();
         let mut parser = Parser::new(sql.as_bytes());
         while let Some(cmd) = parser.next_cmd()? {
@@ -1504,6 +1371,8 @@ impl Connection {
     #[instrument(skip_all, level = Level::INFO)]
     pub fn consume_stmt(self: &Arc, sql: &str) -> Result> {
+        self.ensure_schema_loaded()?;
+        self.maybe_update_schema();
         let mut parser = Parser::new(sql.as_bytes());
         let Some(cmd) = parser.next_cmd()? else {
             return Ok(None);
         };
@@ -1892,7 +1761,15 @@ impl Connection {
     }
     pub fn get_page_size(&self) -> PageSize {
         let value = self.page_size.load(Ordering::SeqCst);
-        PageSize::new_from_header_u16(value).unwrap_or_default()
+        if value != 0 {
+            return PageSize::new_from_header_u16(value).unwrap_or_default();
+        }
+        if let Ok(size) = self.pager.load().get_page_size_checked() {
+            self.page_size.store(size.get_raw(), Ordering::SeqCst);
+            size
+        } else {
+            PageSize::default()
+        }
     }
 
     pub fn is_closed(&self) -> bool {
diff --git a/core/storage/encryption.rs b/core/storage/encryption.rs
index f0184dbc04..9b08d1a012 100644
--- a/core/storage/encryption.rs
+++ b/core/storage/encryption.rs
@@ -65,7 +65,7 @@ use turso_macros::{match_ignore_ascii_case, AtomicEnum};
 /// ```
 ///
 /// constants used for the Turso page header in the encrypted dbs.
-const TURSO_HEADER_PREFIX: &[u8] = b"Turso";
+pub const TURSO_HEADER_PREFIX: &[u8] = b"Turso";
 const TURSO_VERSION: u8 = 0x00;
 const VERSION_OFFSET: usize = 5;
 const CIPHER_OFFSET: usize = 6;
diff --git a/core/storage/page_cache.rs b/core/storage/page_cache.rs
index e8ac4e657a..913e27bc88 100644
--- a/core/storage/page_cache.rs
+++ b/core/storage/page_cache.rs
@@ -294,7 +294,7 @@ impl PageCache {
         // However, if we do not evict that page from the page cache, we will return an unloaded page later which will trigger
         // assertions later on. This is worsened by the fact that page cache is not per `Statement`, so you can abort a completion
         // in one Statement, and trigger some error in the next one if we don't evict the page here.
-        if !page.is_loaded() && !page.is_locked() {
+        if (!page.is_loaded() || page.get().contents.is_none()) && !page.is_locked() {
             self.delete(*key)?;
             return Ok(None);
         }
diff --git a/core/storage/pager.rs b/core/storage/pager.rs
index e809673bb2..dc11bcb517 100644
--- a/core/storage/pager.rs
+++ b/core/storage/pager.rs
@@ -1,3 +1,4 @@
+use crate::storage::checksum::CHECKSUM_REQUIRED_RESERVED_BYTES;
 use crate::storage::subjournal::Subjournal;
 use crate::storage::wal::IOV_MAX;
 use crate::storage::{
@@ -11,8 +12,8 @@ use crate::storage::{
 use crate::types::{IOCompletions, WalState};
 use crate::util::IOExt as _;
 use crate::{
-    io::CompletionGroup, return_if_io, turso_assert, types::WalFrameInfo, Completion, Connection,
-    IOResult, LimboError, Result, TransactionState,
+    io::CompletionGroup, return_if_io, turso_assert, types::WalFrameInfo, Buffer, Completion,
+    Connection, IOResult, LimboError, Result, TransactionState,
 };
 use crate::{io_yield_one, CompletionError, IOContext, OpenFlags, IO};
 use parking_lot::RwLock;
@@ -49,11 +50,14 @@ const PENDING_BYTE: u32 = 0x40000000;
 use ptrmap::*;
 
 #[derive(Debug, Clone)]
-pub struct HeaderRef(PageRef);
+pub struct HeaderRef(HeaderPage);
 
 impl HeaderRef {
     pub fn from_pager(pager: &Pager) -> Result> {
         loop {
+            if let Some(guard) = pager.try_cached_header_guard() {
+                return Ok(IOResult::Done(Self(guard)));
+            }
             let state = pager.header_ref_state.read().clone();
             tracing::trace!("HeaderRef::from_pager - {:?}", state);
             match state {
@@ -74,8 +78,10 @@ impl HeaderRef {
                         page.get().id == DatabaseHeader::PAGE_ID,
                         "incorrect header page id"
                     );
+                    Pager::validate_header_page(&page)?;
+                    let guard = pager.store_cached_header_page(page);
                     *pager.header_ref_state.write() = HeaderRefState::Start;
-                    break Ok(IOResult::Done(Self(page)));
+                    break Ok(IOResult::Done(Self(guard)));
                 }
             }
         }
@@ -83,17 +89,21 @@ impl HeaderRef {
     pub fn borrow(&self) -> &DatabaseHeader {
         // TODO: Instead of erasing mutability, implement `get_mut_contents` and return a shared reference.
-        let content: &PageContent = self.0.get_contents();
+        let content: &PageContent = self.0.page().get_contents();
         bytemuck::from_bytes::(&content.buffer.as_slice()[0..DatabaseHeader::SIZE])
     }
 }
 
 #[derive(Debug, Clone)]
-pub struct HeaderRefMut(PageRef);
+pub struct HeaderRefMut(HeaderPage);
 
 impl HeaderRefMut {
     pub fn from_pager(pager: &Pager) -> Result> {
         loop {
+            if let Some(guard) = pager.try_cached_header_guard() {
+                pager.add_dirty(guard.page().as_ref())?;
+                return Ok(IOResult::Done(Self(guard)));
+            }
             let state = pager.header_ref_state.read().clone();
             tracing::trace!(?state);
             match state {
@@ -114,23 +124,108 @@ impl HeaderRefMut {
                         page.get().id == DatabaseHeader::PAGE_ID,
                         "incorrect header page id"
                     );
-
+                    Pager::validate_header_page(&page)?;
                     pager.add_dirty(&page)?;
                     *pager.header_ref_state.write() = HeaderRefState::Start;
-                    break Ok(IOResult::Done(Self(page)));
+                    let guard = pager.store_cached_header_page(page);
+                    break Ok(IOResult::Done(Self(guard)));
                 }
             }
         }
     }
 
     pub fn borrow_mut(&self) -> &mut DatabaseHeader {
-        let content = self.0.get_contents();
+        let content = self.0.page().get_contents();
         bytemuck::from_bytes_mut::(
             &mut content.buffer.as_mut_slice()[0..DatabaseHeader::SIZE],
         )
     }
 }
 
+#[derive(Debug, Clone)]
+pub struct HeaderPage(PageRef);
+
+impl HeaderPage {
+    fn new(page: PageRef) -> Self {
+        Self(page)
+    }
+
+    fn page(&self) -> &PageRef {
+        &self.0
+    }
+}
+
+#[cfg(test)]
+mod header_cache_tests {
+    use super::*;
+    use crate::storage::database::DatabaseFile;
+    use crate::storage::page_cache::PageCache;
+    use crate::{MemoryIO, OpenFlags};
+    use std::sync::{Arc, Mutex};
+
+    fn create_test_pager() -> Pager {
+        let io: Arc<dyn IO> = Arc::new(MemoryIO::new());
+        let file = io
+            .open_file("header-cache-test", OpenFlags::Create, false)
+            .unwrap();
+        let db_file = Arc::new(DatabaseFile::new(file));
+        let buffer_pool = BufferPool::begin_init(&io, BufferPool::DEFAULT_ARENA_SIZE);
+        buffer_pool
+            .finalize_with_page_size(PageSize::default().get() as usize)
+            .unwrap();
+        let page_cache = Arc::new(RwLock::new(PageCache::default()));
+        let db_state = Arc::new(AtomicDbState::new(DbState::Uninitialized));
+        let pager = Pager::new(
+            db_file,
+            None,
+            io.clone(),
+            page_cache,
+            buffer_pool,
+            db_state,
+            Arc::new(Mutex::new(())),
+        )
+        .unwrap();
+        pager.set_page_size(PageSize::default());
+        pager
+    }
+
+    fn run_io<T>(pager: &Pager, mut f: impl FnMut() -> Result<IOResult<T>>) -> T {
+        loop {
+            match f().unwrap() {
+                crate::types::IOResult::Done(v) => break v,
+                crate::types::IOResult::IO(c) => {
+                    c.wait(pager.io.as_ref()).unwrap();
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn cached_header_invalidated_when_contents_drop() {
+        let pager = create_test_pager();
+        run_io(&pager, || pager.maybe_allocate_page1());
+        let header = run_io(&pager, || HeaderRef::from_pager(&pager));
+        assert!(pager.header_page.read().is_some());
+        header.0.page().get().contents.take();
+        drop(header);
+        assert!(pager.try_cached_header_guard().is_none());
+        run_io(&pager, || HeaderRef::from_pager(&pager));
+    }
+
+    #[test]
+    fn cached_header_reused_without_re_read() {
+        let pager = create_test_pager();
+        run_io(&pager, || pager.maybe_allocate_page1());
+
+        let first = run_io(&pager, || HeaderRef::from_pager(&pager));
+        let cached = first.0.page().clone();
+        drop(first);
+
+        let second = run_io(&pager, || HeaderRef::from_pager(&pager));
+        assert!(Arc::ptr_eq(second.0.page(), &cached));
+    }
+}
+
 pub struct PageInner {
     pub flags: AtomicUsize,
     pub contents: Option,
@@ -522,6 +617,7 @@ pub struct Pager {
     checkpoint_state: RwLock,
     syncing: Arc,
     auto_vacuum_mode: AtomicU8,
+    autovacuum_allowed: AtomicBool,
     /// 0 -> Database is empty,
     /// 1 -> Database is being initialized,
     /// 2 -> Database is initialized and ready for use.
@@ -546,6 +642,7 @@ pub struct Pager {
     /// Maximum number of pages allowed in the database. Default is 1073741823 (SQLite default).
     max_page_count: AtomicU32,
     header_ref_state: RwLock,
+    header_page: RwLock<Option<PageRef>>,
     #[cfg(not(feature = "omit_autovacuum"))]
     vacuum_state: RwLock,
     pub(crate) io_ctx: RwLock,
@@ -650,6 +747,7 @@ impl Pager {
             checkpoint_state: RwLock::new(CheckpointState::Checkpoint),
             buffer_pool,
             auto_vacuum_mode: AtomicU8::new(AutoVacuumMode::None.into()),
+            autovacuum_allowed: AtomicBool::new(true),
             db_state,
             init_lock,
             allocate_page1_state,
@@ -660,6 +758,7 @@ impl Pager {
             allocate_page_state: RwLock::new(AllocatePageState::Start),
             max_page_count: AtomicU32::new(DEFAULT_MAX_PAGE_COUNT),
             header_ref_state: RwLock::new(HeaderRefState::Start),
+            header_page: RwLock::new(None),
            #[cfg(not(feature = "omit_autovacuum"))]
             vacuum_state: RwLock::new(VacuumState {
                 ptrmap_get_state: PtrMapGetState::Start,
@@ -677,6 +776,121 @@ impl Pager {
         Ok(())
     }
 
+    pub(crate) fn allow_autovacuum(&self, allowed: bool) {
+        self.autovacuum_allowed.store(allowed, Ordering::SeqCst);
+    }
+
+    fn ensure_page_size_initialized(&self) -> Result<()> {
+        if self.page_size.load(Ordering::SeqCst) != 0 {
+            return Ok(());
+        }
+
+        if !self.db_state.get().is_initialized() {
+            let configured = self.page_size.load(Ordering::SeqCst);
+            let size = if configured == 0 {
+                PageSize::default().get()
+            } else {
+                configured
+            };
+            self.page_size.store(size, Ordering::SeqCst);
+            self.buffer_pool.finalize_with_page_size(size as usize)?;
+            return Ok(());
+        }
+
+        let buf = self.read_db_header_block()?;
+        let header_bytes = &buf.as_slice()[0..DatabaseHeader::SIZE];
+        DatabaseHeader::validate_bytes(header_bytes)?;
+        let header = bytemuck::from_bytes::<DatabaseHeader>(header_bytes);
+
+        self.page_size
+            .store(header.page_size.get(), Ordering::SeqCst);
+        self.buffer_pool
+            .finalize_with_page_size(header.page_size.get() as usize)?;
+        self.set_reserved_space(header.reserved_space);
+        if header.reserved_space != CHECKSUM_REQUIRED_RESERVED_BYTES {
+            self.reset_checksum_context();
+        }
+
+        let header_mode = if header.vacuum_mode_largest_root_page.get() > 0 {
+            if header.incremental_vacuum_enabled.get() > 0 {
+                AutoVacuumMode::Incremental
+            } else {
+                AutoVacuumMode::Full
+            }
+        } else {
+            AutoVacuumMode::None
+        };
+
+        let final_mode = if self.autovacuum_allowed.load(Ordering::SeqCst) {
+            header_mode
+        } else {
+            AutoVacuumMode::None
+        };
+        self.set_auto_vacuum_mode(final_mode);
+
+        Ok(())
+    }
+
+    fn read_db_header_block(&self) -> Result<Arc<Buffer>> {
+        let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
+        let completion = Completion::new_read(buf.clone(), move |_res| {});
+        let completion = self.db_file.read_header(completion)?;
+        self.io.wait_for_completion(completion)?;
+        Ok(buf)
+    }
+
+    fn try_cached_header_guard(&self) -> Option<HeaderPage> {
+        let maybe_page = {
+            let guard = self.header_page.read();
+            guard.as_ref().cloned()
+        };
+        let Some(page) = maybe_page else {
+            return None;
+        };
+        if page.get().contents.is_some() {
+            return Some(HeaderPage::new(page));
+        }
+        self.clear_cached_header_page();
+        None
+    }
+
+    fn update_cached_header_page(&self, page: &PageRef) {
+        let mut guard = self.header_page.write();
+        if guard
+            .as_ref()
+            .map(|cached| Arc::ptr_eq(cached, page))
+            .unwrap_or(false)
+        {
+            return;
+        }
+        if let Some(old) = guard.take() {
+            if old.is_pinned() {
+                old.unpin();
+            }
+        }
+        page.pin();
+        *guard = Some(page.clone());
+    }
+
+    fn store_cached_header_page(&self, page: PageRef) -> HeaderPage {
+        self.update_cached_header_page(&page);
+        HeaderPage::new(page)
+    }
+
+    fn clear_cached_header_page(&self) {
+        let mut guard = self.header_page.write();
+        if let Some(page) = guard.take() {
+            if page.is_pinned() {
+                page.unpin();
+            }
+        }
+    }
+
+    fn validate_header_page(page: &PageRef) -> Result<()> {
+        let content: &PageContent = page.get_contents();
+        DatabaseHeader::validate_bytes(&content.buffer.as_slice()[0..DatabaseHeader::SIZE])
+    }
+
     /// Open the subjournal if not yet open.
     /// The subjournal is a file that is used to store the "before images" of pages for the
     /// current savepoint. If the savepoint is rolled back, the pages can be restored from the subjournal.
@@ -1342,6 +1556,16 @@ impl Pager {
         }
     }
 
+    pub fn get_page_size_checked(&self) -> Result<PageSize> {
+        if let Some(size) = self.get_page_size() {
+            return Ok(size);
+        }
+        self.ensure_page_size_initialized()?;
+        self.get_page_size().ok_or_else(|| {
+            LimboError::InternalError("page size not set after header load".to_string())
+        })
+    }
+
     /// Get the current page size, panicking if not set.
     pub fn get_page_size_unchecked(&self) -> PageSize {
         let value = self.page_size.load(Ordering::SeqCst);
@@ -1526,6 +1750,7 @@ impl Pager {
         frame_watermark: Option,
         allow_empty_read: bool,
     ) -> Result<(PageRef, Completion)> {
+        self.ensure_page_size_initialized()?;
         assert!(page_idx >= 0);
         tracing::debug!("read_page_no_cache(page_idx = {})", page_idx);
         let page = Arc::new(Page::new(page_idx));
@@ -1563,6 +1788,7 @@ impl Pager {
     /// Reads a page from the database.
     #[tracing::instrument(skip_all, level = Level::DEBUG)]
     pub fn read_page(&self, page_idx: i64) -> Result<(PageRef, Option)> {
+        self.ensure_page_size_initialized()?;
         assert!(page_idx >= 0, "pages in pager should be positive, negative might indicate unallocated pages from mvcc or any other nasty bug");
         tracing::debug!("read_page(page_idx = {})", page_idx);
         let mut page_cache = self.page_cache.write();
@@ -1806,7 +2032,7 @@ impl Pager {
             trace!(?state);
             match state {
                 CommitState::PrepareWal => {
-                    let page_sz = self.get_page_size_unchecked();
+                    let page_sz = self.get_page_size_checked()?;
                     let c = wal.borrow_mut().prepare_wal_start(page_sz)?;
                     let Some(c) = c else {
                         self.commit_info.write().state = CommitState::GetDbSize;
@@ -1844,7 +2070,7 @@ impl Pager {
                     let mut commit_info = self.commit_info.write();
                     commit_info.completions.clear();
                 }
-                let page_sz = self.get_page_size_unchecked();
+                let page_sz = self.get_page_size_checked()?;
                 let mut pages: Vec = Vec::with_capacity(dirty_ids.len().min(IOV_MAX));
                 let total = dirty_ids.len();
                 let mut cache = self.page_cache.write();
@@ -2120,6 +2346,7 @@ impl Pager {
     /// of a rollback or in case we want to invalidate page cache after starting a read transaction
     /// right after new writes happened which would invalidate current page cache.
     pub fn clear_page_cache(&self, clear_dirty: bool) {
+        self.clear_cached_header_page();
         let dirty_pages = self.dirty_pages.read();
         let mut cache = self.page_cache.write();
         for page_id in dirty_pages.iter() {
@@ -2713,6 +2940,9 @@ impl Pager {
         })?;
         page.set_loaded();
         page.clear_wal_tag();
+        if id == DatabaseHeader::PAGE_ID {
+            self.update_cached_header_page(&page);
+        }
         Ok(())
     }
 
@@ -2793,7 +3023,7 @@ impl Pager {
             ));
         }
 
-        let page_size = self.get_page_size_unchecked().get() as usize;
+        let page_size = self.get_page_size_checked()?.get() as usize;
         let encryption_ctx = EncryptionContext::new(cipher_mode, key, page_size)?;
         {
             let mut io_ctx = self.io_ctx.write();
diff --git a/core/storage/sqlite3_ondisk.rs b/core/storage/sqlite3_ondisk.rs
index 4a1e4e407f..a559974f3c 100644
--- a/core/storage/sqlite3_ondisk.rs
+++ b/core/storage/sqlite3_ondisk.rs
@@ -59,6 +59,7 @@ use crate::storage::btree::offset::{
 use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
 use crate::storage::buffer_pool::BufferPool;
 use crate::storage::database::{DatabaseStorage, EncryptionOrChecksum};
+use crate::storage::encryption::TURSO_HEADER_PREFIX;
 use crate::storage::pager::Pager;
 use crate::storage::wal::READMARK_NOT_USED;
 use crate::types::{SerialType, SerialTypeKind, TextSubtype, ValueRef};
@@ -295,6 +296,7 @@ pub struct DatabaseHeader {
 }
 
 impl DatabaseHeader {
+    pub const MAGIC: [u8; 16] = *b"SQLite format 3\0";
     pub const PAGE_ID: usize = 1;
     pub const SIZE: usize = size_of::();
 
@@ -305,12 +307,51 @@ impl DatabaseHeader {
     pub fn usable_space(self) -> usize {
         (self.page_size.get() as usize) - (self.reserved_space as usize)
     }
+
+    pub fn validate_bytes(header: &[u8]) -> Result<()> {
+        if header.len() < Self::SIZE {
+            return Err(LimboError::NotADB);
+        }
+
+        let has_sqlite_magic = header.starts_with(&Self::MAGIC);
+        let has_turso_magic = header.starts_with(TURSO_HEADER_PREFIX);
+        if !has_sqlite_magic && !has_turso_magic {
+            return Err(LimboError::NotADB);
+        }
+
+        let write_version = header[18];
+        let read_version = header[19];
+        if !matches!(write_version, 1 | 2) || !matches!(read_version, 1 | 2) {
+            return Err(LimboError::NotADB);
+        }
+
+        if header[21] != 64 || header[22] != 32 || header[23] != 32 {
+            return Err(LimboError::NotADB);
+        }
+
+        let page_size_raw = u16::from_be_bytes([header[16], header[17]]);
+        let page_size = match PageSize::new_from_header_u16(page_size_raw) {
+            Ok(size) => size.get() as usize,
+            Err(_) => return Err(LimboError::NotADB),
+        };
+
+        let reserved_bytes = header[20] as usize;
+        if reserved_bytes >= page_size {
+            return Err(LimboError::NotADB);
+        }
+
+        if page_size - reserved_bytes < 480 {
+            return Err(LimboError::NotADB);
+        }
+
+        Ok(())
+    }
 }
 
 impl Default for DatabaseHeader {
     fn default() -> Self {
         Self {
-            magic: *b"SQLite format 3\0",
+            magic: Self::MAGIC,
             page_size: Default::default(),
             write_version: Version::Wal,
             read_version: Version::Wal,
@@ -1006,7 +1047,7 @@ pub fn write_pages_vectored(
         return Ok(Vec::new());
     }
 
-    let page_sz = pager.get_page_size_unchecked().get() as usize;
+    let page_sz = pager.get_page_size_checked()?.get() as usize;
     // Count expected number of runs to create the atomic counter we need to track each batch
     let mut run_count = 0;
     let mut prev_id = None;
diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs
index 1bfc278a9f..79e00e47f1 100644
--- a/core/vdbe/execute.rs
+++ b/core/vdbe/execute.rs
@@ -8389,7 +8389,7 @@ pub fn op_integrity_check(
                 crate::storage::pager::AutoVacuumMode::None
             ) {
             tracing::debug!("Integrity check: auto-vacuum mode detected ({:?}). Scanning for pointer-map pages.", auto_vacuum_mode);
-            let page_size = pager.get_page_size_unchecked().get() as usize;
+            let page_size = pager.get_page_size_checked()?.get() as usize;
 
             for page_number in 2..=integrity_check_state.db_size {
                 if crate::storage::pager::ptrmap::is_ptrmap_page(
diff --git a/tests/integration/storage/header.rs b/tests/integration/storage/header.rs
new file mode 100644
index 0000000000..106a0e5452
--- /dev/null
+++ b/tests/integration/storage/header.rs
@@ -0,0 +1,29 @@
+use crate::common::maybe_setup_tracing;
+use std::sync::Arc;
+use tempfile::TempDir;
+use turso_core::{Database, DatabaseOpts, LimboError, OpenFlags};
+
+#[test]
+fn invalid_database_errors_on_first_query() -> anyhow::Result<()> {
+    maybe_setup_tracing();
+    let dir = TempDir::new()?;
+    let db_path = dir.path().join("invalid.db");
+    std::fs::write(&db_path, b"definitely not a database")?;
+
+    let io: Arc<dyn turso_core::IO> = Arc::new(turso_core::PlatformIO::new()?);
+    let db = Database::open_file_with_flags(
+        io,
+        db_path.to_str().unwrap(),
+        OpenFlags::ReadOnly,
+        DatabaseOpts::new(),
+        None,
+    )?;
+
+    let conn = db.connect()?;
+    let err = conn
+        .prepare("SELECT name FROM sqlite_master LIMIT 1")
+        .unwrap_err();
+    assert!(matches!(err, LimboError::NotADB));
+
+    Ok(())
+}
diff --git a/tests/integration/storage/mod.rs b/tests/integration/storage/mod.rs
index eb366ee729..b2a24f0e9d 100644
--- a/tests/integration/storage/mod.rs
+++ b/tests/integration/storage/mod.rs
@@ -1,2 +1,3 @@
 #[cfg(feature = "checksum")]
 mod checksum;
+mod header;