diff --git a/src/bin/main.rs b/src/bin/main.rs index 731cad8e..ef2c54b3 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -136,6 +136,10 @@ pub struct MainOptions { /// See `crate::GenerationConfig::diesel_backend` for more details. #[arg(short = 'b', long = "diesel-backend")] pub diesel_backend: String, + + /// Generate the "default" function in an `impl Default` + #[arg(long)] + pub default_impl: bool, } #[derive(Debug, ValueEnum, Clone, PartialEq, Default)] @@ -265,6 +269,7 @@ fn actual_main() -> dsync::Result<()> { once_connection_type: args.once_connection_type, readonly_prefixes: args.readonly_prefixes, readonly_suffixes: args.readonly_suffixes, + default_impl: args.default_impl, }, }, )?; diff --git a/src/code.rs b/src/code.rs index 44629b74..568848c3 100644 --- a/src/code.rs +++ b/src/code.rs @@ -1,6 +1,5 @@ use heck::ToPascalCase; use indoc::formatdoc; -use std::borrow::Cow; use crate::parser::{ParsedColumnMacro, ParsedTableMacro, FILE_SIGNATURE}; use crate::{get_table_module_name, GenerationConfig, TableOptions}; @@ -86,7 +85,7 @@ pub struct StructField { impl StructField { /// Assemble the current options into a rust type, like `base_type: String, is_optional: true` to `Option` - pub fn to_rust_type(&self) -> Cow<'_, str> { + pub fn to_rust_type(&self) -> std::borrow::Cow<'_, str> { let mut rust_type = self.base_type.clone(); // order matters! @@ -209,6 +208,7 @@ impl<'a> Struct<'a> { derives::SELECTABLE, #[cfg(feature = "derive-queryablebyname")] derives::QUERYABLEBYNAME, + derives::PARTIALEQ, ]); if !self.table.foreign_keys.is_empty() { @@ -228,7 +228,9 @@ impl<'a> Struct<'a> { derives_vec.push(derives::PARTIALEQ); } - derives_vec.push(derives::DEFAULT); + if !self.config.options.default_impl { + derives_vec.push(derives::DEFAULT); + } } StructType::Create => derives_vec.extend_from_slice(&[derives::INSERTABLE]), } @@ -297,7 +299,7 @@ impl<'a> Struct<'a> { .collect::>() .join(" "); - let fields = self.fields(); + let mut fields = self.fields(); if fields.is_empty() { self.has_fields = Some(false); @@ -330,18 +332,18 @@ impl<'a> Struct<'a> { }; let mut lines = Vec::with_capacity(fields.len()); - for mut f in fields.into_iter() { + for f in fields.iter_mut() { let field_name = &f.name; if f.base_type == "String" { f.base_type = match self.ty { - StructType::Read => f.base_type, + StructType::Read => f.base_type.clone(), StructType::Update => self.opts.get_update_str_type().as_str().to_string(), StructType::Create => self.opts.get_create_str_type().as_str().to_string(), } } else if f.base_type == "Vec" { f.base_type = match self.ty { - StructType::Read => f.base_type, + StructType::Read => f.base_type.clone(), StructType::Update => self.opts.get_update_bytes_type().as_str().to_string(), StructType::Create => self.opts.get_create_bytes_type().as_str().to_string(), } @@ -380,7 +382,7 @@ impl<'a> Struct<'a> { ), }; - let struct_code = formatdoc!( + let mut struct_code = formatdoc!( r#" {doccomment} {tsync_attr}{derive_attr} @@ -407,6 +409,15 @@ impl<'a> Struct<'a> { lines = lines.join("\n"), ); + if self.config.options.default_impl { + struct_code.push('\n'); + struct_code.push_str(&build_default_impl_fn( + self.ty, + &ty.format(&table.struct_name), + &fields, + )); + } + self.has_fields = Some(true); self.rendered_code = Some(struct_code); } @@ -761,10 +772,108 @@ fn build_imports(table: &ParsedTableMacro, config: &GenerationConfig) -> String imports_vec.join("\n") } +/// Get default for type +fn default_for_type(typ: &str) -> &'static str { + match typ { + "i8" | "u8" | "i16" | "u16" | "i32" | "u32" | "i64" | "u64" | "i128" | "u128" | "isize" + | "usize" => "0", + "f32" | "f64" => "0.0", + // https://doc.rust-lang.org/std/primitive.bool.html#method.default + "bool" => "false", + "String" => "String::new()", + "&str" | "&'static str" => "\"\"", + "Cow" => "Cow::Owned(String::new())", + _ => { + if typ.starts_with("Option<") { + "None" + } else { + "Default::default()" + } + } + } +} + +/// Generate default (insides of the `impl Default for StructName { fn default() -> Self {} }`) +fn build_default_impl_fn<'a>( + struct_type: StructType, + struct_name: &str, + fields: &[StructField], +) -> String { + let fields: Vec = fields + .iter() + .map(|name_typ_nullable| { + format!( + "{name}: {typ_default},", + name = name_typ_nullable.name, + typ_default = if name_typ_nullable.is_optional || struct_type == StructType::Update + { + "None" + } else { + default_for_type(&name_typ_nullable.base_type) + } + ) + }) + .collect(); + formatdoc!( + r#" + impl Default for {struct_name} {{ + fn default() -> Self {{ + Self {{ + {fields} + }} + }} + }} + "#, + fields = fields.join("\n ") + ) +} + +#[test] +fn test_build_default_impl_fn() { + let fields = vec![ + StructField { + name: String::from("id"), + column_name: String::from("id"), + base_type: String::from("i32"), + is_optional: false, + is_vec: false, + }, + StructField { + name: String::from("title"), + column_name: String::from("title"), + base_type: String::from("String"), + is_optional: false, + is_vec: false, + }, + StructField { + name: String::from("maybe_value"), + column_name: String::from("maybe_value"), + base_type: String::from("i64"), + is_optional: true, + is_vec: false, + }, + ]; + + let generated_code = build_default_impl_fn(StructType::Create, "CreateFake", &fields); + + let expected = r#"impl Default for CreateFake { + fn default() -> Self { + Self { + id: 0, + title: String::new(), + maybe_value: None, + } + } +} +"#; + + assert_eq!(&generated_code, &expected); +} + /// Generate a full file for a given diesel table pub fn generate_for_table(table: &ParsedTableMacro, config: &GenerationConfig) -> String { // early to ensure the table options are set for the current table - let table_options = config.table(&table.name.to_string()); + let table_options = config.table(table.name.to_string().as_str()); let mut ret_buffer = format!("{FILE_SIGNATURE}\n\n"); @@ -789,7 +898,7 @@ pub fn generate_for_table(table: &ParsedTableMacro, config: &GenerationConfig) - ret_buffer.push_str(update_struct.code()); } - // third and lastly, push functions - if enabled + // third, push functions - if enabled if table_options.get_fns() { ret_buffer.push('\n'); ret_buffer.push_str(build_table_fns(table, config, create_struct, update_struct).as_str()); diff --git a/src/global.rs b/src/global.rs index ad1eaa0c..327a1e23 100644 --- a/src/global.rs +++ b/src/global.rs @@ -328,6 +328,8 @@ pub struct GenerationConfigOpts<'a> { pub readonly_prefixes: Vec, /// Suffixes to treat tables as readonly pub readonly_suffixes: Vec, + /// Generate the "default" function in an `impl Default` + pub default_impl: bool, } impl GenerationConfigOpts<'_> { @@ -363,6 +365,7 @@ impl Default for GenerationConfigOpts<'_> { once_connection_type: false, readonly_prefixes: Vec::default(), readonly_suffixes: Vec::default(), + default_impl: false, } } } diff --git a/test/default_impl/Cargo.toml b/test/default_impl/Cargo.toml new file mode 100644 index 00000000..c8b0b907 --- /dev/null +++ b/test/default_impl/Cargo.toml @@ -0,0 +1,18 @@ +[lib] +path = "lib.rs" + +[package] +name = "default_impl" +version = "0.1.0" +edition = "2021" + +[dependencies] +diesel = { version = "*", default-features = false, features = [ + "sqlite", + "r2d2", + "chrono", + "returning_clauses_for_sqlite_3_35", +] } +r2d2.workspace = true +chrono.workspace = true +serde.workspace = true diff --git a/test/default_impl/lib.rs b/test/default_impl/lib.rs new file mode 100644 index 00000000..fdea3f51 --- /dev/null +++ b/test/default_impl/lib.rs @@ -0,0 +1,6 @@ +pub mod models; +pub mod schema; + +pub mod diesel { + pub use diesel::*; +} diff --git a/test/default_impl/models/mod.rs b/test/default_impl/models/mod.rs new file mode 100644 index 00000000..015a6a2b --- /dev/null +++ b/test/default_impl/models/mod.rs @@ -0,0 +1 @@ +pub mod todos; diff --git a/test/default_impl/models/todos/generated.rs b/test/default_impl/models/todos/generated.rs new file mode 100644 index 00000000..aac39819 --- /dev/null +++ b/test/default_impl/models/todos/generated.rs @@ -0,0 +1,151 @@ +/* @generated and managed by dsync */ + +#[allow(unused)] +use crate::diesel::*; +use crate::schema::*; + +pub type ConnectionType = diesel::r2d2::PooledConnection>; + +/// Struct representing a row in table `todos` +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, diesel::Queryable, diesel::Selectable, diesel::QueryableByName, PartialEq, diesel::Identifiable)] +#[diesel(table_name=todos, primary_key(id))] +pub struct Todos { + /// Field representing column `id` + pub id: i32, + /// Field representing column `text` + pub text: String, + /// Field representing column `completed` + pub completed: bool, + /// Field representing column `type` + pub type_: String, + /// Field representing column `smallint` + pub smallint: i16, + /// Field representing column `bigint` + pub bigint: i64, + /// Field representing column `created_at` + pub created_at: chrono::NaiveDateTime, + /// Field representing column `updated_at` + pub updated_at: chrono::NaiveDateTime, +} + +impl Default for Todos { + fn default() -> Self { + Self { + id: 0, + text: String::new(), + completed: false, + type_: String::new(), + smallint: 0, + bigint: 0, + created_at: Default::default(), + updated_at: Default::default(), + } + } +} + +/// Create Struct for a row in table `todos` for [`Todos`] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, diesel::Insertable)] +#[diesel(table_name=todos)] +pub struct CreateTodos { + /// Field representing column `text` + pub text: String, + /// Field representing column `completed` + pub completed: bool, + /// Field representing column `type` + pub type_: String, + /// Field representing column `smallint` + pub smallint: i16, + /// Field representing column `bigint` + pub bigint: i64, +} + +impl Default for CreateTodos { + fn default() -> Self { + Self { + text: String::new(), + completed: false, + type_: String::new(), + smallint: 0, + bigint: 0, + } + } +} + +/// Update Struct for a row in table `todos` for [`Todos`] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, diesel::AsChangeset, PartialEq)] +#[diesel(table_name=todos)] +pub struct UpdateTodos { + /// Field representing column `text` + pub text: Option, + /// Field representing column `completed` + pub completed: Option, + /// Field representing column `type` + pub type_: Option, + /// Field representing column `smallint` + pub smallint: Option, + /// Field representing column `bigint` + pub bigint: Option, + /// Field representing column `created_at` + pub created_at: Option, + /// Field representing column `updated_at` + pub updated_at: Option, +} + +impl Default for UpdateTodos { + fn default() -> Self { + Self { + text: String::new(), + completed: false, + type_: String::new(), + smallint: 0, + bigint: 0, + created_at: Default::default(), + updated_at: Default::default(), + } + } +} + +/// Result of a `.paginate` function +#[derive(Debug, serde::Serialize)] +pub struct PaginationResult { + /// Resulting items that are from the current page + pub items: Vec, + /// The count of total items there are + pub total_items: i64, + /// Current page, 0-based index + pub page: i64, + /// Size of a page + pub page_size: i64, + /// Number of total possible pages, given the `page_size` and `total_items` + pub num_pages: i64, +} + +impl Todos { + /// Insert a new row into `todos` with a given [`CreateTodos`] + pub fn create(db: &mut ConnectionType, item: &CreateTodos) -> diesel::QueryResult { + use crate::schema::todos::dsl::*; + + diesel::insert_into(todos).values(item).get_result::(db) + } + + /// Get a row from `todos`, identified by the primary key + pub fn read(db: &mut ConnectionType, param_id: i32) -> diesel::QueryResult { + use crate::schema::todos::dsl::*; + + todos.filter(id.eq(param_id)).first::(db) + } + + /// Update a row in `todos`, identified by the primary key with [`UpdateTodos`] + pub fn update(db: &mut ConnectionType, param_id: i32, item: &UpdateTodos) -> diesel::QueryResult { + use crate::schema::todos::dsl::*; + + diesel::update(todos.filter(id.eq(param_id))).set(item).get_result(db) + } + + /// Delete a row in `todos`, identified by the primary key + pub fn delete(db: &mut ConnectionType, param_id: i32) -> diesel::QueryResult { + use crate::schema::todos::dsl::*; + + diesel::delete(todos.filter(id.eq(param_id))).execute(db) + } +} diff --git a/test/default_impl/models/todos/mod.rs b/test/default_impl/models/todos/mod.rs new file mode 100644 index 00000000..a5bb9b90 --- /dev/null +++ b/test/default_impl/models/todos/mod.rs @@ -0,0 +1,2 @@ +pub use generated::*; +pub mod generated; diff --git a/test/default_impl/schema.rs b/test/default_impl/schema.rs new file mode 100644 index 00000000..c91a5c4f --- /dev/null +++ b/test/default_impl/schema.rs @@ -0,0 +1,16 @@ +diesel::table! { + todos (id) { + id -> Int4, + // unsigned -> Unsigned, + // unsigned_nullable -> Nullable>, + text -> Text, + completed -> Bool, + #[sql_name = "type"] + #[max_length = 255] + type_ -> Varchar, + smallint -> Int2, + bigint -> Int8, + created_at -> Timestamp, + updated_at -> Timestamp, + } +} diff --git a/test/default_impl/test.sh b/test/default_impl/test.sh new file mode 100755 index 00000000..7c33aaf3 --- /dev/null +++ b/test/default_impl/test.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +cd $SCRIPT_DIR + +cargo run --manifest-path ../../Cargo.toml -- \ +-i schema.rs -o models -g id -g created_at -g updated_at --default-impl -c "diesel::r2d2::PooledConnection>"