Skip to content

Commit

Permalink
Avoid ID clashes
Browse files Browse the repository at this point in the history
Currently the the ID for a new paste is randomly generated in the caller
of the database insert() function.  Then the insert() function tries to
insert a new row into the database with that passed ID.  There can
however already exists a paste in the database with the same ID leading
to an insert failure, due to a constraint violation due to the PRIMARY
KEY attribute.  Checking prior the the INSERT via a SELECT query would
open the window for a race condition.

A failure to push a new paste is quite severe, since the user might have
spent some some to format the input.

Generate the ID in a loop inside, until the INSERT succeeds.
  • Loading branch information
cgzones committed Nov 11, 2024
1 parent 9d7df4f commit 301db9e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 65 deletions.
85 changes: 59 additions & 26 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,32 +247,68 @@ impl Database {
Ok(Self { conn })
}

/// Insert `entry` under `id` into the database and optionally set owner to `uid`.
pub async fn insert(&self, id: Id, entry: write::Entry) -> Result<(), Error> {
/// Insert `entry` with a new generated `id` into the database and optionally set owner to `uid`.
pub async fn insert(&self, entry: write::Entry) -> Result<Id, Error> {
let conn = self.conn.clone();
let id = id.as_u32();
let write::DatabaseEntry { entry, data, nonce } = entry.compress().await?.encrypt().await?;

spawn_blocking(move || match entry.expires {
None => conn.lock().execute(
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce) VALUES (?1, ?2, ?3, ?4, ?5)",
params![id, entry.uid, data, entry.burn_after_reading, nonce],
),
Some(expires) => conn.lock().execute(
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6))",
params![
id,
entry.uid,
data,
entry.burn_after_reading,
nonce,
format!("{expires} seconds")
],
),
let id = spawn_blocking(move || {
const COUNTER_LIMIT: u32 = 10;
let mut counter = 0;

let mut rng = rand::thread_rng();

loop {
let id: Id = rand::Rng::gen::<u32>(&mut rng).into();
let id_inner = id.as_u32();

let result = match entry.expires {
None => conn.lock().execute(
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce) VALUES (?1, ?2, ?3, ?4, ?5)",
params![id_inner, entry.uid, data, entry.burn_after_reading, nonce],
),
Some(expires) => conn.lock().execute(
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6))",
params![
id_inner,
entry.uid,
data,
entry.burn_after_reading,
nonce,
format!("{expires} seconds")
],
),
};

match result {
Err(rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code, extended_code }, Some(ref _message)))
if code == rusqlite::ErrorCode::ConstraintViolation && extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY && counter < COUNTER_LIMIT => {
/* Retry if ID is already existent */
counter += 1;
continue;
},
Err(err) => {
if counter >= COUNTER_LIMIT {
tracing::error!("Failed to generate ID after {counter} retries");
}

break Err(err)
},
Ok(rows) => {
debug_assert!(rows == 1);

if counter > 4 {
tracing::warn!("Required {counter} retries to generate new ID");
}

break Ok(id)
},
}
}
})
.await??;

Ok(())
Ok(id)
}

/// Get entire entry for `id`.
Expand Down Expand Up @@ -383,8 +419,7 @@ mod tests {
..Default::default()
};

let id = Id::from(1234);
db.insert(id, entry).await?;
let id = db.insert(entry).await?;

let entry = db.get(id, None).await?;
assert_eq!(entry.text, "hello world");
Expand All @@ -406,8 +441,7 @@ mod tests {
..Default::default()
};

let id = Id::from(1234);
db.insert(id, entry).await?;
let id = db.insert(entry).await?;

tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

Expand All @@ -422,8 +456,7 @@ mod tests {
async fn delete() -> Result<(), Box<dyn std::error::Error>> {
let db = new_db()?;

let id = Id::from(1234);
db.insert(id, write::Entry::default()).await?;
let id = db.insert(write::Entry::default()).await?;

assert!(db.get(id, None).await.is_ok());
assert!(db.delete(id).await.is_ok());
Expand Down
8 changes: 2 additions & 6 deletions src/id.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::db::write::Entry;
use crate::errors::Error;
use std::fmt;
use std::str::FromStr;
Expand All @@ -23,11 +22,8 @@ impl Id {
}

/// Generate a URL path from the string representation and `entry`'s extension.
pub fn to_url_path(self, entry: &Entry) -> String {
entry
.extension
.as_ref()
.map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}"))
pub fn to_url_path(self, extension: Option<&str>) -> String {
extension.map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}"))
}
}

Expand Down
32 changes: 12 additions & 20 deletions src/routes/form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ use std::num::NonZeroU32;

use crate::db::write;
use crate::env::BASE_PATH;
use crate::id::Id;
use crate::{pages, AppState, Error};
use axum::extract::{Form, State};
use axum::response::Redirect;
use axum_extra::extract::cookie::{Cookie, SameSite, SignedCookieJar};
use rand::Rng;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -45,14 +43,6 @@ pub async fn insert(
Form(entry): Form<Entry>,
is_https: bool,
) -> Result<(SignedCookieJar, Redirect), pages::ErrorResponse<'static>> {
let id: Id = tokio::task::spawn_blocking(|| {
let mut rng = rand::thread_rng();
rng.gen::<u32>()
})
.await
.map_err(Error::from)?
.into();

// Retrieve uid from cookie or generate a new one.
let uid = if let Some(cookie) = jar.get("uid") {
cookie
Expand All @@ -66,22 +56,24 @@ pub async fn insert(
let mut entry: write::Entry = entry.into();
entry.uid = Some(uid);

let mut url = id.to_url_path(&entry);

let burn_after_reading = entry.burn_after_reading.unwrap_or(false);
if burn_after_reading {
url = format!("burn/{url}");
}

let url_with_base = BASE_PATH.join(&url);

if let Some(max_exp) = state.max_expiration {
entry.expires = entry
.expires
.map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp)));
}

state.db.insert(id, entry).await?;
let burn = entry.burn_after_reading.unwrap_or(false);
let extension = entry.extension.clone();

let id = state.db.insert(entry).await?;

let mut url = id.to_url_path(extension.as_deref());

if burn {
url = format!("burn/{url}");
}

let url_with_base = BASE_PATH.join(&url);

let cookie = Cookie::build(("uid", uid.to_string()))
.http_only(true)
Expand Down
18 changes: 5 additions & 13 deletions src/routes/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ use std::num::NonZeroU32;

use crate::db::write;
use crate::env::BASE_PATH;
use crate::errors::{Error, JsonErrorResponse};
use crate::id::Id;
use crate::errors::JsonErrorResponse;
use crate::AppState;
use axum::extract::State;
use axum::Json;
use rand::Rng;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -41,14 +39,6 @@ pub async fn insert(
state: State<AppState>,
Json(entry): Json<Entry>,
) -> Result<Json<RedirectResponse>, JsonErrorResponse> {
let id: Id = tokio::task::spawn_blocking(|| {
let mut rng = rand::thread_rng();
rng.gen::<u32>()
})
.await
.map_err(Error::from)?
.into();

let mut entry: write::Entry = entry.into();

if let Some(max_exp) = state.max_expiration {
Expand All @@ -57,9 +47,11 @@ pub async fn insert(
.map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp)));
}

let url = id.to_url_path(&entry);
let extension = entry.extension.clone();

let id = state.db.insert(entry).await?;
let url = id.to_url_path(extension.as_deref());
let path = BASE_PATH.join(&url);
state.db.insert(id, entry).await?;

Ok(Json::from(RedirectResponse { path }))
}

0 comments on commit 301db9e

Please sign in to comment.