From 0685193b8d3e63d96a5b37d457d1540a01629ff9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20G=C3=B6ttsche?= Date: Sat, 26 Oct 2024 16:40:50 +0200 Subject: [PATCH] Avoid ID clashes Currently the the ID for a new paste is randomly generated in the caller of the database insert() function. Then the insert() function tries to insert a new row into the database with that passed ID. There can however already exists a paste in the database with the same ID leading to an insert failure, due to a constraint violation due to the PRIMARY KEY attribute. Checking prior the the INSERT via a SELECT query would open the window for a race condition. A failure to push a new paste is quite severe, since the user might have spent some some to format the input. Generate the ID in a loop inside, until the INSERT succeeds. --- src/db.rs | 87 ++++++++++++++++++++++++++++++++-------------- src/id.rs | 8 ++--- src/routes/form.rs | 32 +++++++---------- src/routes/json.rs | 18 +++------- 4 files changed, 79 insertions(+), 66 deletions(-) diff --git a/src/db.rs b/src/db.rs index 044a4be..1ece397 100644 --- a/src/db.rs +++ b/src/db.rs @@ -259,33 +259,69 @@ impl Database { Ok(Self { conn }) } - /// Insert `entry` under `id` into the database and optionally set owner to `uid`. - pub async fn insert(&self, id: Id, entry: write::Entry) -> Result<(), Error> { + /// Insert `entry` with a new generated `id` into the database and optionally set owner to `uid`. + pub async fn insert(&self, entry: write::Entry) -> Result { let conn = self.conn.clone(); - let id = id.as_u32(); let write::DatabaseEntry { entry, data, nonce } = entry.compress().await?.encrypt().await?; - spawn_blocking(move || match entry.expires { - None => conn.lock().execute( - "INSERT INTO entries (id, uid, data, burn_after_reading, nonce, title) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", - params![id, entry.uid, data, entry.burn_after_reading, nonce, entry.title], - ), - Some(expires) => conn.lock().execute( - "INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires, title) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6), ?7)", - params![ - id, - entry.uid, - data, - entry.burn_after_reading, - nonce, - format!("{expires} seconds"), - entry.title, - ], - ), + let id = spawn_blocking(move || { + const COUNTER_LIMIT: u32 = 10; + let mut counter = 0; + + let mut rng = rand::thread_rng(); + + loop { + let id: Id = rand::Rng::gen::(&mut rng).into(); + let id_inner = id.as_u32(); + + let result = match entry.expires { + None => conn.lock().execute( + "INSERT INTO entries (id, uid, data, burn_after_reading, nonce, title) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + params![id_inner, entry.uid, data, entry.burn_after_reading, nonce, entry.title], + ), + Some(expires) => conn.lock().execute( + "INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires, title) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6), ?7)", + params![ + id_inner, + entry.uid, + data, + entry.burn_after_reading, + nonce, + format!("{expires} seconds"), + entry.title, + ], + ), + }; + + match result { + Err(rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code, extended_code }, Some(ref _message))) + if code == rusqlite::ErrorCode::ConstraintViolation && extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY && counter < COUNTER_LIMIT => { + /* Retry if ID is already existent */ + counter += 1; + continue; + }, + Err(err) => { + if counter >= COUNTER_LIMIT { + tracing::error!("Failed to generate ID after {counter} retries"); + } + + break Err(err) + }, + Ok(rows) => { + debug_assert!(rows == 1); + + if counter > 4 { + tracing::warn!("Required {counter} retries to generate new ID"); + } + + break Ok(id) + }, + } + } }) .await??; - Ok(()) + Ok(id) } /// Get entire entry for `id`. @@ -397,8 +433,7 @@ mod tests { ..Default::default() }; - let id = Id::from(1234); - db.insert(id, entry).await?; + let id = db.insert(entry).await?; let entry = db.get(id, None).await?; assert_eq!(entry.text, "hello world"); @@ -420,8 +455,7 @@ mod tests { ..Default::default() }; - let id = Id::from(1234); - db.insert(id, entry).await?; + let id = db.insert(entry).await?; tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; @@ -436,8 +470,7 @@ mod tests { async fn delete() -> Result<(), Box> { let db = new_db()?; - let id = Id::from(1234); - db.insert(id, write::Entry::default()).await?; + let id = db.insert(write::Entry::default()).await?; assert!(db.get(id, None).await.is_ok()); assert!(db.delete(id).await.is_ok()); diff --git a/src/id.rs b/src/id.rs index 9df9a97..3197e2a 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1,4 +1,3 @@ -use crate::db::write::Entry; use crate::errors::Error; use std::fmt; use std::str::FromStr; @@ -23,11 +22,8 @@ impl Id { } /// Generate a URL path from the string representation and `entry`'s extension. - pub fn to_url_path(self, entry: &Entry) -> String { - entry - .extension - .as_ref() - .map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}")) + pub fn to_url_path(self, extension: Option<&str>) -> String { + extension.map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}")) } } diff --git a/src/routes/form.rs b/src/routes/form.rs index 846759e..861631e 100644 --- a/src/routes/form.rs +++ b/src/routes/form.rs @@ -2,12 +2,10 @@ use std::num::NonZeroU32; use crate::db::write; use crate::env::BASE_PATH; -use crate::id::Id; use crate::{pages, AppState, Error}; use axum::extract::{Form, State}; use axum::response::Redirect; use axum_extra::extract::cookie::{Cookie, SameSite, SignedCookieJar}; -use rand::Rng; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -48,14 +46,6 @@ pub async fn insert( Form(entry): Form, is_https: bool, ) -> Result<(SignedCookieJar, Redirect), pages::ErrorResponse<'static>> { - let id: Id = tokio::task::spawn_blocking(|| { - let mut rng = rand::thread_rng(); - rng.gen::() - }) - .await - .map_err(Error::from)? - .into(); - // Retrieve uid from cookie or generate a new one. let uid = if let Some(cookie) = jar.get("uid") { cookie @@ -69,22 +59,24 @@ pub async fn insert( let mut entry: write::Entry = entry.into(); entry.uid = Some(uid); - let mut url = id.to_url_path(&entry); - - let burn_after_reading = entry.burn_after_reading.unwrap_or(false); - if burn_after_reading { - url = format!("burn/{url}"); - } - - let url_with_base = BASE_PATH.join(&url); - if let Some(max_exp) = state.max_expiration { entry.expires = entry .expires .map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp))); } - state.db.insert(id, entry).await?; + let burn = entry.burn_after_reading.unwrap_or(false); + let extension = entry.extension.clone(); + + let id = state.db.insert(entry).await?; + + let mut url = id.to_url_path(extension.as_deref()); + + if burn { + url = format!("burn/{url}"); + } + + let url_with_base = BASE_PATH.join(&url); let cookie = Cookie::build(("uid", uid.to_string())) .http_only(true) diff --git a/src/routes/json.rs b/src/routes/json.rs index 698e02a..0c162f5 100644 --- a/src/routes/json.rs +++ b/src/routes/json.rs @@ -2,12 +2,10 @@ use std::num::NonZeroU32; use crate::db::write; use crate::env::BASE_PATH; -use crate::errors::{Error, JsonErrorResponse}; -use crate::id::Id; +use crate::errors::JsonErrorResponse; use crate::AppState; use axum::extract::State; use axum::Json; -use rand::Rng; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -43,14 +41,6 @@ pub async fn insert( state: State, Json(entry): Json, ) -> Result, JsonErrorResponse> { - let id: Id = tokio::task::spawn_blocking(|| { - let mut rng = rand::thread_rng(); - rng.gen::() - }) - .await - .map_err(Error::from)? - .into(); - let mut entry: write::Entry = entry.into(); if let Some(max_exp) = state.max_expiration { @@ -59,9 +49,11 @@ pub async fn insert( .map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp))); } - let url = id.to_url_path(&entry); + let extension = entry.extension.clone(); + + let id = state.db.insert(entry).await?; + let url = id.to_url_path(extension.as_deref()); let path = BASE_PATH.join(&url); - state.db.insert(id, entry).await?; Ok(Json::from(RedirectResponse { path })) }