Skip to content

Commit caf528e

Browse files
committed
Validate page 1 header
1 parent 36c3489 commit caf528e

File tree

4 files changed

+123
-23
lines changed

4 files changed

+123
-23
lines changed

core/lib.rs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub use storage::database::IOContext;
9090
pub use storage::encryption::{CipherMode, EncryptionContext, EncryptionKey};
9191
use storage::page_cache::PageCache;
9292
use storage::pager::{AtomicDbState, DbState};
93-
use storage::sqlite3_ondisk::PageSize;
93+
use storage::sqlite3_ondisk::{DatabaseHeader, PageSize};
9494
pub use storage::{
9595
buffer_pool::BufferPool,
9696
database::DatabaseStorage,
@@ -662,41 +662,36 @@ impl Database {
662662
self.open_flags.get().contains(OpenFlags::ReadOnly)
663663
}
664664

665-
/// If we do not have a physical WAL file, but we know the database file is initialized on disk,
666-
/// we need to read the page_size from the database header.
667-
fn read_page_size_from_db_header(&self) -> Result<PageSize> {
665+
/// Read the first page-size block from disk and verify it looks like a database header.
666+
fn read_db_header_block(&self) -> Result<Arc<Buffer>> {
668667
turso_assert!(
669668
self.db_state.get().is_initialized(),
670-
"read_page_size_from_db_header called on uninitialized database"
669+
"read_db_header_block called on uninitialized database"
671670
);
672671
turso_assert!(
673-
PageSize::MIN % 512 == 0,
672+
PageSize::MIN.is_multiple_of(512),
674673
"header read must be a multiple of 512 for O_DIRECT"
675674
);
676675
let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
677676
let c = Completion::new_read(buf.clone(), move |_res| {});
678677
let c = self.db_file.read_header(c)?;
679678
self.io.wait_for_completion(c)?;
679+
DatabaseHeader::validate_bytes(buf.as_slice())?;
680+
Ok(buf)
681+
}
682+
683+
/// If we do not have a physical WAL file, but we know the database file is initialized on disk,
684+
/// we need to read the page_size from the database header.
685+
fn read_page_size_from_db_header(&self) -> Result<PageSize> {
686+
let buf = self.read_db_header_block()?;
680687
let page_size = u16::from_be_bytes(buf.as_slice()[16..18].try_into().unwrap());
681688
let page_size = PageSize::new_from_header_u16(page_size)?;
682689
Ok(page_size)
683690
}
684691

685692
fn read_reserved_space_bytes_from_db_header(&self) -> Result<u8> {
686-
turso_assert!(
687-
self.db_state.get().is_initialized(),
688-
"read_reserved_space_bytes_from_db_header called on uninitialized database"
689-
);
690-
turso_assert!(
691-
PageSize::MIN % 512 == 0,
692-
"header read must be a multiple of 512 for O_DIRECT"
693-
);
694-
let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
695-
let c = Completion::new_read(buf.clone(), move |_res| {});
696-
let c = self.db_file.read_header(c)?;
697-
self.io.wait_for_completion(c)?;
698-
let reserved_bytes = u8::from_be_bytes(buf.as_slice()[20..21].try_into().unwrap());
699-
Ok(reserved_bytes)
693+
let buf = self.read_db_header_block()?;
694+
Ok(buf.as_slice()[20])
700695
}
701696

702697
/// Read the page size in order of preference:

core/storage/encryption.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ use turso_macros::{match_ignore_ascii_case, AtomicEnum};
6565
/// ```
6666
///
6767
/// constants used for the Turso page header in the encrypted dbs.
68-
const TURSO_HEADER_PREFIX: &[u8] = b"Turso";
68+
pub const TURSO_HEADER_PREFIX: &[u8] = b"Turso";
6969
const TURSO_VERSION: u8 = 0x00;
7070
const VERSION_OFFSET: usize = 5;
7171
const CIPHER_OFFSET: usize = 6;

core/storage/pager.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ pub struct HeaderRef(PageRef);
5454
impl HeaderRef {
5555
pub fn from_pager(pager: &Pager) -> Result<IOResult<Self>> {
5656
loop {
57+
if let Some(page) = pager.cached_header_page() {
58+
return Ok(IOResult::Done(Self(page)));
59+
}
5760
let state = pager.header_ref_state.read().clone();
5861
tracing::trace!("HeaderRef::from_pager - {:?}", state);
5962
match state {
@@ -74,6 +77,8 @@ impl HeaderRef {
7477
page.get().id == DatabaseHeader::PAGE_ID,
7578
"incorrect header page id"
7679
);
80+
Pager::validate_header_page(&page)?;
81+
pager.set_cached_header_page(&page);
7782
*pager.header_ref_state.write() = HeaderRefState::Start;
7883
break Ok(IOResult::Done(Self(page)));
7984
}
@@ -94,6 +99,10 @@ pub struct HeaderRefMut(PageRef);
9499
impl HeaderRefMut {
95100
pub fn from_pager(pager: &Pager) -> Result<IOResult<Self>> {
96101
loop {
102+
if let Some(page) = pager.cached_header_page() {
103+
pager.add_dirty(page.as_ref())?;
104+
return Ok(IOResult::Done(Self(page)));
105+
}
97106
let state = pager.header_ref_state.read().clone();
98107
tracing::trace!(?state);
99108
match state {
@@ -114,6 +123,8 @@ impl HeaderRefMut {
114123
page.get().id == DatabaseHeader::PAGE_ID,
115124
"incorrect header page id"
116125
);
126+
Pager::validate_header_page(&page)?;
127+
pager.set_cached_header_page(&page);
117128

118129
pager.add_dirty(&page)?;
119130
*pager.header_ref_state.write() = HeaderRefState::Start;
@@ -546,6 +557,7 @@ pub struct Pager {
546557
/// Maximum number of pages allowed in the database. Default is 1073741823 (SQLite default).
547558
max_page_count: AtomicU32,
548559
header_ref_state: RwLock<HeaderRefState>,
560+
header_page: RwLock<Option<PageRef>>,
549561
#[cfg(not(feature = "omit_autovacuum"))]
550562
vacuum_state: RwLock<VacuumState>,
551563
pub(crate) io_ctx: RwLock<IOContext>,
@@ -660,6 +672,7 @@ impl Pager {
660672
allocate_page_state: RwLock::new(AllocatePageState::Start),
661673
max_page_count: AtomicU32::new(DEFAULT_MAX_PAGE_COUNT),
662674
header_ref_state: RwLock::new(HeaderRefState::Start),
675+
header_page: RwLock::new(None),
663676
#[cfg(not(feature = "omit_autovacuum"))]
664677
vacuum_state: RwLock::new(VacuumState {
665678
ptrmap_get_state: PtrMapGetState::Start,
@@ -677,6 +690,53 @@ impl Pager {
677690
Ok(())
678691
}
679692

693+
/// Returns the cached header page if one is stored and it still has contents.
///
/// A cached page whose contents are gone is treated as stale: the cache slot is
/// cleared and `None` is returned.
fn cached_header_page(&self) -> Option<PageRef> {
    // The read guard is a temporary of this statement, so it is released before
    // `clear_cached_header_page` takes the write lock below.
    let candidate = self.header_page.read().clone()?;
    if candidate.get().contents.is_some() {
        return Some(candidate);
    }
    // Release our own reference before evicting the stale entry.
    drop(candidate);
    self.clear_cached_header_page();
    None
}
707+
708+
fn set_cached_header_page(&self, page: &PageRef) {
709+
let mut guard = self.header_page.write();
710+
if guard
711+
.as_ref()
712+
.map(|cached| Arc::ptr_eq(cached, page))
713+
.unwrap_or(false)
714+
{
715+
return;
716+
}
717+
if let Some(old) = guard.take() {
718+
if old.is_pinned() {
719+
old.unpin();
720+
}
721+
}
722+
page.pin();
723+
*guard = Some(page.clone());
724+
}
725+
726+
fn clear_cached_header_page(&self) {
727+
let mut guard = self.header_page.write();
728+
if let Some(page) = guard.take() {
729+
if page.is_pinned() {
730+
page.unpin();
731+
}
732+
}
733+
}
734+
735+
fn validate_header_page(page: &PageRef) -> Result<()> {
736+
let content: &PageContent = page.get_contents();
737+
DatabaseHeader::validate_bytes(&content.buffer.as_slice()[0..DatabaseHeader::SIZE])
738+
}
739+
680740
/// Open the subjournal if not yet open.
681741
/// The subjournal is a file that is used to store the "before images" of pages for the
682742
/// current savepoint. If the savepoint is rolled back, the pages can be restored from the subjournal.
@@ -2120,6 +2180,7 @@ impl Pager {
21202180
/// of a rollback or in case we want to invalidate page cache after starting a read transaction
21212181
/// right after new writes happened which would invalidate current page cache.
21222182
pub fn clear_page_cache(&self, clear_dirty: bool) {
2183+
self.clear_cached_header_page();
21232184
let dirty_pages = self.dirty_pages.read();
21242185
let mut cache = self.page_cache.write();
21252186
for page_id in dirty_pages.iter() {
@@ -2713,6 +2774,9 @@ impl Pager {
27132774
})?;
27142775
page.set_loaded();
27152776
page.clear_wal_tag();
2777+
if id == DatabaseHeader::PAGE_ID {
2778+
self.set_cached_header_page(&page);
2779+
}
27162780
Ok(())
27172781
}
27182782

core/storage/sqlite3_ondisk.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ use crate::storage::btree::offset::{
5858
};
5959
use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
6060
use crate::storage::buffer_pool::BufferPool;
61-
use crate::storage::database::{DatabaseStorage, EncryptionOrChecksum};
61+
use crate::storage::database::{DatabaseFile, DatabaseStorage, EncryptionOrChecksum};
62+
use crate::storage::encryption::TURSO_HEADER_PREFIX;
6263
use crate::storage::pager::Pager;
6364
use crate::storage::wal::READMARK_NOT_USED;
6465
use crate::types::{SerialType, SerialTypeKind, TextSubtype, ValueRef};
@@ -295,6 +296,7 @@ pub struct DatabaseHeader {
295296
}
296297

297298
impl DatabaseHeader {
299+
pub const MAGIC: [u8; 16] = *b"SQLite format 3\0";
298300
pub const PAGE_ID: usize = 1;
299301
pub const SIZE: usize = size_of::<Self>();
300302

@@ -305,12 +307,51 @@ impl DatabaseHeader {
305307
pub fn usable_space(self) -> usize {
306308
(self.page_size.get() as usize) - (self.reserved_space as usize)
307309
}
310+
311+
/// Checks that `header` plausibly holds a database header.
///
/// The slice must be at least `Self::SIZE` bytes long and satisfy all of:
/// - it starts with either `Self::MAGIC` or `TURSO_HEADER_PREFIX`;
/// - the write version (byte 18) and read version (byte 19) are each 1 or 2;
/// - bytes 21, 22 and 23 are exactly 64, 32 and 32;
/// - the big-endian page size in bytes 16..18 is accepted by
///   `PageSize::new_from_header_u16`;
/// - the page size minus the reserved bytes (byte 20) is at least 480.
///
/// # Errors
/// Returns `LimboError::NotADB` when any of the checks above fails.
pub fn validate_bytes(header: &[u8]) -> Result<()> {
    // Length is checked first so the fixed-offset indexing below cannot go
    // out of bounds.
    if header.len() < Self::SIZE {
        return Err(LimboError::NotADB);
    }

    let is_known_version = |version: u8| matches!(version, 1 | 2);

    let magic_ok = header.starts_with(&Self::MAGIC) || header.starts_with(TURSO_HEADER_PREFIX);
    let versions_ok = is_known_version(header[18]) && is_known_version(header[19]);
    let fixed_bytes_ok = header[21..24] == [64, 32, 32];
    if !(magic_ok && versions_ok && fixed_bytes_ok) {
        return Err(LimboError::NotADB);
    }

    let raw_page_size = u16::from_be_bytes([header[16], header[17]]);
    let page_size = PageSize::new_from_header_u16(raw_page_size)
        .map_err(|_| LimboError::NotADB)?
        .get() as usize;
    let reserved = usize::from(header[20]);

    // `checked_sub` yields `None` when the reserved bytes exceed the page size,
    // and a remainder of zero is rejected by the 480-byte minimum.
    match page_size.checked_sub(reserved) {
        Some(usable) if usable >= 480 => Ok(()),
        _ => Err(LimboError::NotADB),
    }
}
308349
}
309350

310351
impl Default for DatabaseHeader {
311352
fn default() -> Self {
312353
Self {
313-
magic: *b"SQLite format 3\0",
354+
magic: Self::MAGIC,
314355
page_size: Default::default(),
315356
write_version: Version::Wal,
316357
read_version: Version::Wal,

0 commit comments

Comments
 (0)