Skip to content

Commit 9f13e0f

Browse files
committed
Validate page 1 header
1 parent 2c49c47 commit 9f13e0f

File tree

4 files changed

+122
-22
lines changed

4 files changed

+122
-22
lines changed

core/lib.rs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub use storage::database::IOContext;
9090
pub use storage::encryption::{CipherMode, EncryptionContext, EncryptionKey};
9191
use storage::page_cache::PageCache;
9292
use storage::pager::{AtomicDbState, DbState};
93-
use storage::sqlite3_ondisk::PageSize;
93+
use storage::sqlite3_ondisk::{DatabaseHeader, PageSize};
9494
pub use storage::{
9595
buffer_pool::BufferPool,
9696
database::DatabaseStorage,
@@ -662,41 +662,36 @@ impl Database {
662662
self.open_flags.get().contains(OpenFlags::ReadOnly)
663663
}
664664

665-
/// If we do not have a physical WAL file, but we know the database file is initialized on disk,
666-
/// we need to read the page_size from the database header.
667-
fn read_page_size_from_db_header(&self) -> Result<PageSize> {
665+
/// Read the first page-size block from disk and verify it looks like a database header.
666+
fn read_db_header_block(&self) -> Result<Arc<Buffer>> {
668667
turso_assert!(
669668
self.db_state.get().is_initialized(),
670-
"read_page_size_from_db_header called on uninitialized database"
669+
"read_db_header_block called on uninitialized database"
671670
);
672671
turso_assert!(
673-
PageSize::MIN % 512 == 0,
672+
PageSize::MIN.is_multiple_of(512),
674673
"header read must be a multiple of 512 for O_DIRECT"
675674
);
676675
let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
677676
let c = Completion::new_read(buf.clone(), move |_res| {});
678677
let c = self.db_file.read_header(c)?;
679678
self.io.wait_for_completion(c)?;
679+
DatabaseHeader::validate_bytes(buf.as_slice())?;
680+
Ok(buf)
681+
}
682+
683+
/// If we do not have a physical WAL file, but we know the database file is initialized on disk,
684+
/// we need to read the page_size from the database header.
685+
fn read_page_size_from_db_header(&self) -> Result<PageSize> {
686+
let buf = self.read_db_header_block()?;
680687
let page_size = u16::from_be_bytes(buf.as_slice()[16..18].try_into().unwrap());
681688
let page_size = PageSize::new_from_header_u16(page_size)?;
682689
Ok(page_size)
683690
}
684691

685692
fn read_reserved_space_bytes_from_db_header(&self) -> Result<u8> {
686-
turso_assert!(
687-
self.db_state.get().is_initialized(),
688-
"read_reserved_space_bytes_from_db_header called on uninitialized database"
689-
);
690-
turso_assert!(
691-
PageSize::MIN % 512 == 0,
692-
"header read must be a multiple of 512 for O_DIRECT"
693-
);
694-
let buf = Arc::new(Buffer::new_temporary(PageSize::MIN as usize));
695-
let c = Completion::new_read(buf.clone(), move |_res| {});
696-
let c = self.db_file.read_header(c)?;
697-
self.io.wait_for_completion(c)?;
698-
let reserved_bytes = u8::from_be_bytes(buf.as_slice()[20..21].try_into().unwrap());
699-
Ok(reserved_bytes)
693+
let buf = self.read_db_header_block()?;
694+
Ok(buf.as_slice()[20])
700695
}
701696

702697
/// Read the page size in order of preference:

core/storage/encryption.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ use turso_macros::{match_ignore_ascii_case, AtomicEnum};
6565
/// ```
6666
///
6767
/// Constants used for the Turso page header in encrypted databases.
68-
const TURSO_HEADER_PREFIX: &[u8] = b"Turso";
68+
pub const TURSO_HEADER_PREFIX: &[u8] = b"Turso";
6969
const TURSO_VERSION: u8 = 0x00;
7070
const VERSION_OFFSET: usize = 5;
7171
const CIPHER_OFFSET: usize = 6;

core/storage/pager.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ pub struct HeaderRef(PageRef);
5555
impl HeaderRef {
5656
pub fn from_pager(pager: &Pager) -> Result<IOResult<Self>> {
5757
loop {
58+
if let Some(page) = pager.cached_header_page() {
59+
return Ok(IOResult::Done(Self(page)));
60+
}
5861
let state = pager.header_ref_state.read().clone();
5962
tracing::trace!("HeaderRef::from_pager - {:?}", state);
6063
match state {
@@ -75,6 +78,8 @@ impl HeaderRef {
7578
page.get().id == DatabaseHeader::PAGE_ID,
7679
"incorrect header page id"
7780
);
81+
Pager::validate_header_page(&page)?;
82+
pager.set_cached_header_page(&page);
7883
*pager.header_ref_state.write() = HeaderRefState::Start;
7984
break Ok(IOResult::Done(Self(page)));
8085
}
@@ -95,6 +100,10 @@ pub struct HeaderRefMut(PageRef);
95100
impl HeaderRefMut {
96101
pub fn from_pager(pager: &Pager) -> Result<IOResult<Self>> {
97102
loop {
103+
if let Some(page) = pager.cached_header_page() {
104+
pager.add_dirty(page.as_ref())?;
105+
return Ok(IOResult::Done(Self(page)));
106+
}
98107
let state = pager.header_ref_state.read().clone();
99108
tracing::trace!(?state);
100109
match state {
@@ -115,6 +124,8 @@ impl HeaderRefMut {
115124
page.get().id == DatabaseHeader::PAGE_ID,
116125
"incorrect header page id"
117126
);
127+
Pager::validate_header_page(&page)?;
128+
pager.set_cached_header_page(&page);
118129

119130
pager.add_dirty(&page)?;
120131
*pager.header_ref_state.write() = HeaderRefState::Start;
@@ -547,6 +558,7 @@ pub struct Pager {
547558
/// Maximum number of pages allowed in the database. Default is 1073741823 (SQLite default).
548559
max_page_count: AtomicU32,
549560
header_ref_state: RwLock<HeaderRefState>,
561+
header_page: RwLock<Option<PageRef>>,
550562
#[cfg(not(feature = "omit_autovacuum"))]
551563
vacuum_state: RwLock<VacuumState>,
552564
pub(crate) io_ctx: RwLock<IOContext>,
@@ -661,6 +673,7 @@ impl Pager {
661673
allocate_page_state: RwLock::new(AllocatePageState::Start),
662674
max_page_count: AtomicU32::new(DEFAULT_MAX_PAGE_COUNT),
663675
header_ref_state: RwLock::new(HeaderRefState::Start),
676+
header_page: RwLock::new(None),
664677
#[cfg(not(feature = "omit_autovacuum"))]
665678
vacuum_state: RwLock::new(VacuumState {
666679
ptrmap_get_state: PtrMapGetState::Start,
@@ -678,6 +691,53 @@ impl Pager {
678691
Ok(())
679692
}
680693

694+
/// Return the cached header page, if one is stored and still has contents.
///
/// When the stored page no longer has contents, the cache slot is cleared via
/// `clear_cached_header_page` and `None` is returned, so the caller falls back
/// to its regular load path.
fn cached_header_page(&self) -> Option<PageRef> {
    // The read guard is a temporary of this statement, so it is released
    // before the slot may be cleared (which takes the write lock) below.
    let candidate = self.header_page.read().clone()?;
    if candidate.get().contents.is_none() {
        // Release our own reference first, then evict the stale entry.
        drop(candidate);
        self.clear_cached_header_page();
        return None;
    }
    Some(candidate)
}
708+
709+
fn set_cached_header_page(&self, page: &PageRef) {
710+
let mut guard = self.header_page.write();
711+
if guard
712+
.as_ref()
713+
.map(|cached| Arc::ptr_eq(cached, page))
714+
.unwrap_or(false)
715+
{
716+
return;
717+
}
718+
if let Some(old) = guard.take() {
719+
if old.is_pinned() {
720+
old.unpin();
721+
}
722+
}
723+
page.pin();
724+
*guard = Some(page.clone());
725+
}
726+
727+
fn clear_cached_header_page(&self) {
728+
let mut guard = self.header_page.write();
729+
if let Some(page) = guard.take() {
730+
if page.is_pinned() {
731+
page.unpin();
732+
}
733+
}
734+
}
735+
736+
fn validate_header_page(page: &PageRef) -> Result<()> {
737+
let content: &PageContent = page.get_contents();
738+
DatabaseHeader::validate_bytes(&content.buffer.as_slice()[0..DatabaseHeader::SIZE])
739+
}
740+
681741
/// Open the subjournal if not yet open.
682742
/// The subjournal is a file that is used to store the "before images" of pages for the
683743
/// current savepoint. If the savepoint is rolled back, the pages can be restored from the subjournal.
@@ -2120,6 +2180,7 @@ impl Pager {
21202180
/// of a rollback or in case we want to invalidate page cache after starting a read transaction
21212181
/// right after new writes happened which would invalidate current page cache.
21222182
pub fn clear_page_cache(&self, clear_dirty: bool) {
2183+
self.clear_cached_header_page();
21232184
let dirty_pages = self.dirty_pages.read();
21242185
let mut cache = self.page_cache.write();
21252186
for page_id in dirty_pages.iter() {
@@ -2713,6 +2774,9 @@ impl Pager {
27132774
})?;
27142775
page.set_loaded();
27152776
page.clear_wal_tag();
2777+
if id == DatabaseHeader::PAGE_ID {
2778+
self.set_cached_header_page(&page);
2779+
}
27162780
Ok(())
27172781
}
27182782

core/storage/sqlite3_ondisk.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ use crate::storage::btree::offset::{
5959
use crate::storage::btree::{payload_overflow_threshold_max, payload_overflow_threshold_min};
6060
use crate::storage::buffer_pool::BufferPool;
6161
use crate::storage::database::{DatabaseFile, DatabaseStorage, EncryptionOrChecksum};
62+
use crate::storage::encryption::TURSO_HEADER_PREFIX;
6263
use crate::storage::pager::Pager;
6364
use crate::storage::wal::READMARK_NOT_USED;
6465
use crate::types::{SerialType, SerialTypeKind, TextSubtype, ValueRef};
@@ -295,6 +296,7 @@ pub struct DatabaseHeader {
295296
}
296297

297298
impl DatabaseHeader {
299+
pub const MAGIC: [u8; 16] = *b"SQLite format 3\0";
298300
pub const PAGE_ID: usize = 1;
299301
pub const SIZE: usize = size_of::<Self>();
300302

@@ -305,12 +307,51 @@ impl DatabaseHeader {
305307
pub fn usable_space(self) -> usize {
306308
(self.page_size.get() as usize) - (self.reserved_space as usize)
307309
}
310+
311+
/// Check that `header` plausibly holds a database header.
///
/// The checks, in order:
/// 1. `header` is at least `Self::SIZE` bytes long;
/// 2. it starts with either `Self::MAGIC` or `TURSO_HEADER_PREFIX`;
/// 3. bytes 18 and 19 (write / read version) are each `1` or `2`;
/// 4. bytes 21, 22 and 23 are exactly `64`, `32` and `32`;
/// 5. the big-endian `u16` at bytes 16..18 is accepted by
///    `PageSize::new_from_header_u16`;
/// 6. the reserved-byte count at byte 20 is smaller than the page size and
///    leaves at least 480 usable bytes per page.
///
/// # Errors
/// Returns `LimboError::NotADB` as soon as any check fails.
pub fn validate_bytes(header: &[u8]) -> Result<()> {
    // Length gate first: every fixed-offset index below relies on it.
    if header.len() < Self::SIZE {
        return Err(LimboError::NotADB);
    }

    let magic_ok = header.starts_with(&Self::MAGIC) || header.starts_with(TURSO_HEADER_PREFIX);
    if !magic_ok {
        return Err(LimboError::NotADB);
    }

    let versions_ok = [header[18], header[19]]
        .iter()
        .all(|&version| matches!(version, 1 | 2));
    if !versions_ok {
        return Err(LimboError::NotADB);
    }

    if (header[21], header[22], header[23]) != (64, 32, 32) {
        return Err(LimboError::NotADB);
    }

    let raw_page_size = u16::from_be_bytes([header[16], header[17]]);
    let Ok(page_size) = PageSize::new_from_header_u16(raw_page_size) else {
        return Err(LimboError::NotADB);
    };
    let page_size = page_size.get() as usize;

    // The `>=` test runs first so the subtraction can never underflow.
    let reserved = header[20] as usize;
    if reserved >= page_size || page_size - reserved < 480 {
        return Err(LimboError::NotADB);
    }

    Ok(())
}
308349
}
309350

310351
impl Default for DatabaseHeader {
311352
fn default() -> Self {
312353
Self {
313-
magic: *b"SQLite format 3\0",
354+
magic: Self::MAGIC,
314355
page_size: Default::default(),
315356
write_version: Version::Wal,
316357
read_version: Version::Wal,

0 commit comments

Comments
 (0)