Skip to content

Commit

Permalink
Add ormlite to simplify sqlx queries
Browse files Browse the repository at this point in the history
  • Loading branch information
SpartanPlume committed Sep 1, 2024
1 parent 1a2f6b0 commit 7003202
Show file tree
Hide file tree
Showing 11 changed files with 835 additions and 538 deletions.
1,202 changes: 725 additions & 477 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion crates/libs/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ chrono = { version = "0.4", default-features = false, features = [
lazy_static = "1.4"
nutype = { version = "0.4", features = ["regex", "serde"] }
nutype_macros = { version = "0.4", features = ["regex"] }
ormlite = { version = "0.19", features = ["postgres", "chrono"] }
regex = "1.10"
serde = { version = "1", features = ["derive"] }
sqlx = { version = "0.7", features = [
serde_json = "1.0"
sqlx = { version = "0.8", features = [
"runtime-tokio",
"postgres",
"macros",
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
DROP DATABASE IF EXISTS tournaments
DROP DATABASE IF EXISTS tournament
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Add migration script here
CREATE TABLE IF NOT EXISTS tournaments(
CREATE TABLE IF NOT EXISTS tournament(
id SERIAL NOT NULL,
name TEXT NOT NULL UNIQUE,
acronym TEXT NOT NULL,
Expand Down
46 changes: 40 additions & 6 deletions crates/libs/core/src/domain/tournament/acronym.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,44 @@ use nutype::nutype;
#[nutype(
sanitize(trim),
validate(not_empty, len_char_max = 16, regex = r#"^[\w!#?@]+$"#),
derive(Debug, Deserialize, Clone, AsRef)
derive(Debug, Serialize, Deserialize, Clone, AsRef)
)]
pub struct TournamentAcronym(String);

use ormlite::postgres::PgArgumentBuffer;
use ormlite::postgres::PgValueRef;
use sqlx::Postgres;

impl sqlx::Type<Postgres> for TournamentAcronym {
fn type_info() -> <Postgres as sqlx::Database>::TypeInfo {
<String as sqlx::Type<Postgres>>::type_info()
}

fn compatible(ty: &<Postgres as sqlx::Database>::TypeInfo) -> bool {
<String as sqlx::Type<Postgres>>::compatible(ty)
}
}

impl sqlx::Encode<'_, Postgres> for TournamentAcronym {
fn encode_by_ref(
&self,
buf: &mut PgArgumentBuffer,
) -> Result<sqlx::encode::IsNull, Box<(dyn std::error::Error + Send + Sync + 'static)>> {
let s = serde_json::to_value(self)?;
let s = s.as_str().unwrap();
<&'_ str as sqlx::Encode<Postgres>>::encode(s, buf)
}
}

impl sqlx::Decode<'_, Postgres> for TournamentAcronym {
fn decode(value: PgValueRef<'_>) -> anyhow::Result<Self, sqlx::error::BoxDynError> {
let value = value.as_str()?;
let value = serde_json::Value::String(value.to_string());
let value = serde_json::from_value(value)?;
Ok(value)
}
}

#[cfg(test)]
mod tests {
use super::TournamentAcronym;
Expand All @@ -17,31 +51,31 @@ mod tests {
#[test]
fn empty_acronym_is_invalid() {
let acronym = "".to_string();
assert_err!(TournamentAcronym::new(acronym));
assert_err!(TournamentAcronym::try_new(acronym));
}

#[test]
fn acronym_with_only_whitespaces_is_invalid() {
let acronym = " ".to_string();
assert_err!(TournamentAcronym::new(acronym));
assert_err!(TournamentAcronym::try_new(acronym));
}

#[test]
fn acronym_longer_than_16_is_invalid() {
let acronym = "a".repeat(17);
assert_err!(TournamentAcronym::new(acronym));
assert_err!(TournamentAcronym::try_new(acronym));
}

#[test]
fn acronym_containing_invalid_character_is_invalid() {
let acronym = "/".to_string();
assert_err!(TournamentAcronym::new(acronym));
assert_err!(TournamentAcronym::try_new(acronym));
}

proptest! {
#[test]
fn acronym_containing_valid_characters_is_valid(acronym in "[a-z][A-Z][0-9]!#?@_") {
assert_ok!(TournamentAcronym::new(acronym));
assert_ok!(TournamentAcronym::try_new(acronym));
}
}
}
48 changes: 8 additions & 40 deletions crates/libs/core/src/domain/tournament/mod.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,27 @@
mod acronym;
mod name;

use anyhow::Context;
use serde::{Deserialize, Serialize};

use crate::DbPool;
use acronym::TournamentAcronym;
use name::TournamentName;

#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
use ormlite::model::*;

#[derive(Model, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Tournament {
pub id: i32,
pub name: String,
pub acronym: String,
#[ormlite(default)]
pub created_at: chrono::DateTime<chrono::Utc>,
#[ormlite(default)]
pub updated_at: chrono::DateTime<chrono::Utc>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct NewTournament {
#[derive(Insert, Debug, Deserialize, Clone)]
#[ormlite(returns = "Tournament")]
pub struct InsertTournament {
pub name: TournamentName,
pub acronym: TournamentAcronym,
}

impl Tournament {
pub async fn get_all(db_pool: &DbPool) -> Result<Vec<Tournament>, anyhow::Error> {
sqlx::query_as!(Tournament, "SELECT * FROM tournaments")
.fetch_all(db_pool)
.await
.context("Could not retrieve tournaments from database")
}

pub async fn get_by_id(db_pool: &DbPool, entry_id: i32) -> Result<Tournament, anyhow::Error> {
sqlx::query_as!(
Tournament,
"SELECT * FROM tournaments WHERE id = $1",
entry_id
)
.fetch_one(db_pool)
.await
.context("Could not retrieve tournament by id from database")
}
}

impl NewTournament {
pub async fn insert(&self, db_pool: &DbPool) -> Result<Tournament, anyhow::Error> {
let tournament = sqlx::query_as!(
Tournament,
"INSERT INTO tournaments (name, acronym) VALUES ($1, $2) RETURNING *",
self.name.as_ref(),
self.acronym.as_ref()
)
.fetch_one(db_pool)
.await
.context("Could not retrieve tournament by id from database")?;
Ok(tournament)
}
}
46 changes: 40 additions & 6 deletions crates/libs/core/src/domain/tournament/name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,44 @@ use nutype::nutype;
#[nutype(
sanitize(trim),
validate(not_empty, len_char_max = 128, regex = r#"^[^"\\]+$"#),
derive(Debug, Deserialize, Clone, AsRef)
derive(Debug, Serialize, Deserialize, Clone, AsRef)
)]
pub struct TournamentName(String);

use ormlite::postgres::PgArgumentBuffer;
use ormlite::postgres::PgValueRef;
use sqlx::Postgres;

impl sqlx::Type<Postgres> for TournamentName {
fn type_info() -> <Postgres as sqlx::Database>::TypeInfo {
<String as sqlx::Type<Postgres>>::type_info()
}

fn compatible(ty: &<Postgres as sqlx::Database>::TypeInfo) -> bool {
<String as sqlx::Type<Postgres>>::compatible(ty)
}
}

impl sqlx::Encode<'_, Postgres> for TournamentName {
fn encode_by_ref(
&self,
buf: &mut PgArgumentBuffer,
) -> Result<sqlx::encode::IsNull, Box<(dyn std::error::Error + Send + Sync + 'static)>> {
let s = serde_json::to_value(self)?;
let s = s.as_str().unwrap();
<&'_ str as sqlx::Encode<Postgres>>::encode(s, buf)
}
}

impl sqlx::Decode<'_, Postgres> for TournamentName {
fn decode(value: PgValueRef<'_>) -> anyhow::Result<Self, sqlx::error::BoxDynError> {
let value = value.as_str()?;
let value = serde_json::Value::String(value.to_string());
let value = serde_json::from_value(value)?;
Ok(value)
}
}

#[cfg(test)]
mod tests {
use super::TournamentName;
Expand All @@ -16,32 +50,32 @@ mod tests {
#[test]
fn empty_name_is_invalid() {
let name = "".to_string();
assert_err!(TournamentName::new(name));
assert_err!(TournamentName::try_new(name));
}

#[test]
fn name_with_only_whitespaces_is_invalid() {
let name = " ".to_string();
assert_err!(TournamentName::new(name));
assert_err!(TournamentName::try_new(name));
}

#[test]
fn name_longer_than_128_is_invalid() {
let name = "a".repeat(129);
assert_err!(TournamentName::new(name));
assert_err!(TournamentName::try_new(name));
}

#[test]
fn name_containing_invalid_characters_is_invalid() {
for name in &['\\', '"'] {
let name = name.to_string();
assert_err!(TournamentName::new(name));
assert_err!(TournamentName::try_new(name));
}
}

#[test]
fn name_containing_valid_characters_is_valid() {
let name = "Tournament name".to_string();
assert_ok!(TournamentName::new(name));
assert_ok!(TournamentName::try_new(name));
}
}
3 changes: 2 additions & 1 deletion crates/services/web-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ chrono = { version = "0.4", default-features = false, features = [
"clock",
"serde",
] }
ormlite = "0.19"
secrecy = { version = "0.8", features = ["serde"] }
serde = { version = "1", features = ["derive"] }
serde-aux = "4"
Expand All @@ -37,4 +38,4 @@ lazy_static = "1.4"
once_cell = "1"
reqwest = { version = "0.11", features = ["json"] }
uuid = { version = "1", features = ["v4"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "macros"] }
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres"] }
6 changes: 6 additions & 0 deletions crates/services/web-server/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use serde::Serialize;
pub enum Error {
#[error(transparent)]
ServerError(#[from] anyhow::Error),
#[error(transparent)]
DatabaseError(#[from] ormlite::Error),
}

impl std::fmt::Debug for Error {
Expand Down Expand Up @@ -42,6 +44,10 @@ impl IntoResponse for Error {
StatusCode::INTERNAL_SERVER_ERROR,
"An unexpected error occurred".to_string(),
),
Error::DatabaseError(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"An unexpected database error occurred".to_string(),
),
};

(status, Json(ErrorResponse { message })).into_response()
Expand Down
8 changes: 4 additions & 4 deletions crates/services/web-server/src/routes/tournaments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ use axum::Json;
use crate::prelude::*;
use tosurnament_core::domain::tournament::*;

#[allow(clippy::async_yields_async)]
use ormlite::model::*;

#[tracing::instrument(skip_all, fields(?body_data.name))]
pub async fn create_tournament(
State(context): State<Context>,
Json(body_data): Json<NewTournament>,
Json(body_data): Json<InsertTournament>,
) -> Result<(StatusCode, Json<Tournament>)> {
let result = body_data.insert(&context.db.pool).await?;
Ok((StatusCode::CREATED, Json(result)))
}

#[allow(clippy::async_yields_async)]
#[tracing::instrument(skip_all)]
pub async fn get_tournaments(State(context): State<Context>) -> Result<Json<Vec<Tournament>>> {
let results = Tournament::get_all(&context.db.pool).await?;
let results = Tournament::select().fetch_all(&context.db.pool).await?;
Ok(Json(results))
}
6 changes: 5 additions & 1 deletion crates/services/web-server/tests/api/tournaments.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use ormlite::model::*;
use std::collections::HashMap;

use tosurnament_core::domain::tournament::*;
Expand All @@ -20,7 +21,10 @@ async fn create_tournament_returns_201_for_valid_data() {
.expect("Invalid tournament object returned by the API");
assert_eq!(created.name, "Tournament name");
assert_eq!(created.acronym, "TN");
let saved = Tournament::get_by_id(&app.context.db.pool, created.id)
let saved = Tournament::select()
.where_("id = ?")
.bind(created.id)
.fetch_one(&app.context.db.pool)
.await
.expect("Could not retrieve tournament from db");
assert_eq!(created, saved);
Expand Down

0 comments on commit 7003202

Please sign in to comment.