Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
62f2720
simplify TempDatabase creation
pedrocarlo Nov 18, 2025
696bf22
create a basic test proc macro that takes a db path and an mvcc flag
pedrocarlo Nov 18, 2025
237ce5b
enhance proc macro to detect the fn args and the return type with bet…
pedrocarlo Nov 18, 2025
19bb802
create two tests when mvcc flag is passed
pedrocarlo Nov 19, 2025
1e85706
pass the fn arg type to the macro closure to avoid errors
pedrocarlo Nov 19, 2025
91e1fc4
create temp database builder to facilitate passing options to proc macro
pedrocarlo Nov 19, 2025
0ffda31
add init_sql arg to proc macro
pedrocarlo Nov 19, 2025
463ca00
migrate test cases to use the macro
pedrocarlo Nov 19, 2025
39ffa76
change `test_write_path` to use new test macros
pedrocarlo Nov 19, 2025
f3854c2
change `test_transactions` to use new test macros
pedrocarlo Nov 19, 2025
5301a04
change `test_read_path` to use new test macros
pedrocarlo Nov 19, 2025
c58cd32
change `test_multi_thread` to use new test macros
pedrocarlo Nov 19, 2025
7895326
change `test_ddl` to use new test macros
pedrocarlo Nov 19, 2025
3bc0499
change `test_btree` to use new test macros
pedrocarlo Nov 19, 2025
4a79d4f
change `encryption` to use new test macros
pedrocarlo Nov 19, 2025
479048e
change `index_methods` to use new test macros
pedrocarlo Nov 19, 2025
c4653ae
change `functions` to use new test macros
pedrocarlo Nov 19, 2025
5aed514
change `trigger` test to use new test macros
pedrocarlo Nov 19, 2025
8b2caab
change `pragma` test to use new test macros
pedrocarlo Nov 19, 2025
963b37f
change `fuzz` test to use new test macros
pedrocarlo Nov 19, 2025
99f8952
add docs to macro
pedrocarlo Nov 20, 2025
a3acec5
clippy
pedrocarlo Nov 20, 2025
2ef2b13
pending byte page cannot run at the same time with mvcc test
pedrocarlo Nov 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 32 additions & 1 deletion macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod ext;
extern crate proc_macro;
mod atomic_enum;
mod ext;
mod test;

use proc_macro::{token_stream::IntoIter, Group, TokenStream, TokenTree};
use std::collections::HashMap;

Expand Down Expand Up @@ -492,3 +494,32 @@ pub fn match_ignore_ascii_case(input: TokenStream) -> TokenStream {
pub fn derive_atomic_enum(input: TokenStream) -> TokenStream {
atomic_enum::derive_atomic_enum_inner(input)
}

/// Test macro for `core_tester` crate
///
/// Generates a runnable Rust test from the following function signature
///
/// ```no_run
/// fn test_x(db: TempDatabase) -> Result<()> {}
/// // Or
/// fn test_y(db: TempDatabase) {}
/// ```
///
/// Macro accepts the following arguments
///
/// - `mvcc` flag: creates an additional test that will run the same code with MVCC enabled
/// - `path` arg: specifies the name of the database to be created
/// - `init_sql` arg: specifies the SQL query that will be run by `rusqlite` before initializing the Turso database
///
/// Example:
/// ```no_run,rust
/// #[turso_macros::test(mvcc, path = "test.db", init_sql = "CREATE TABLE test_rowid (id INTEGER PRIMARY KEY);")]
/// fn test_integer_primary_key(tmp_db: TempDatabase) -> anyhow::Result<()> {
/// // Code goes here to test
/// Ok(())
/// }
/// ```
#[proc_macro_attribute]
pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
test::test_macro_attribute(args, input)
}
318 changes: 318 additions & 0 deletions macros/src/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
use std::{collections::HashSet, ops::Deref};

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{quote, quote_spanned, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
Expr, Ident, ItemFn, Meta, Pat, ReturnType, Token, Type,
};

#[derive(Debug, Clone, Copy)]
struct SpannedType<T>(T, Span);

impl<T> SpannedType<T> {
fn map<U>(self, func: impl FnOnce(T) -> U) -> SpannedType<U> {
SpannedType(func(self.0), self.1)
}
}

impl<T> Deref for SpannedType<T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<T: ToTokens> ToTokens for SpannedType<T> {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let span = self.1;
let val = &self.0;
let out_tokens = quote_spanned! {span=>
#val
};
out_tokens.to_tokens(tokens);
}
}

#[derive(Debug)]
struct Args {
path: Option<SpannedType<String>>,
mvcc: Option<SpannedType<()>>,
init_sql: Option<Expr>,
}

impl Args {
fn get_tmp_db_builder(
&self,
fn_name: &Ident,
tmp_db_ty: &Type,
mvcc: bool,
) -> proc_macro2::TokenStream {
let mut builder = quote! {#tmp_db_ty::builder()};

let db_name = self.path.clone().map_or_else(
|| {
let name = format!("{fn_name}.db");
quote! {#name}
},
|path| path.to_token_stream(),
);

let mut db_opts = quote! {
turso_core::DatabaseOpts::new()
.with_indexes(true)
.with_index_method(true)
.with_encryption(true)
};

if let Some(spanned) = self
.mvcc
.filter(|_| mvcc)
.map(|val| val.map(|_| quote! {.with_mvcc(true)}))
{
db_opts = quote! {
#db_opts
#spanned
}
}

builder = quote! {
#builder
.with_db_name(#db_name)
.with_opts(#db_opts)
};

if let Some(expr) = &self.init_sql {
builder = quote! {
#builder
.with_init_sql(#expr)
};
}

quote! {
#builder.build()
}
}
}

impl Parse for Args {
fn parse(input: ParseStream) -> syn::Result<Self> {
let args = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
let mut seen_args = HashSet::new();

let mut path = None;
let mut mvcc = None;
let mut init_sql = None;

let errors = args
.into_iter()
.filter_map(|meta| {
match meta {
Meta::NameValue(nv) => {
let ident = nv.path.get_ident();
if let Some(ident) = ident {
let ident_string = ident.to_string();
match ident_string.as_str() {
"path" => {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) = &nv.value
{
path = Some(SpannedType(lit_str.value(), nv.value.span()));
seen_args.insert(ident.clone());
} else {
return Some(syn::Error::new_spanned(
nv.value,
"argument is not a string literal",
));
}
}
"init_sql" => {
init_sql = Some(nv.value.clone());
}
_ => {
return Some(syn::Error::new_spanned(
nv.path,
"unexpected argument",
))
}
}
} else {
return Some(syn::Error::new_spanned(nv.path, "unexpected argument"));
}
}
Meta::Path(p) => {
let ident = p.get_ident();
if p.is_ident("mvcc") {
mvcc = Some(SpannedType((), p.span()));
seen_args.insert(ident.unwrap().clone());
} else {
return Some(syn::Error::new_spanned(p, "unexpected flag"));
}
}
_ => {
return Some(syn::Error::new_spanned(meta, "unexpected argument format"));
}
};
None
})
.reduce(|mut accum, err| {
accum.combine(err);
accum
});

if let Some(errors) = errors {
return Err(errors);
}

Ok(Args {
path,
mvcc,
init_sql,
})
}
}

struct DatabaseFunction {
input: ItemFn,
tmp_db_fn_arg: (Pat, syn::Type),
args: Args,
}

impl DatabaseFunction {
fn new(input: ItemFn, tmp_db_fn_arg: (Pat, syn::Type), args: Args) -> Self {
Self {
input,
tmp_db_fn_arg,
args,
}
}

fn tokens_for_db_type(&self, mvcc: bool) -> proc_macro2::TokenStream {
let ItemFn {
attrs,
vis,
sig,
block,
} = &self.input;

let fn_name = if mvcc {
Ident::new(&format!("{}_mvcc", sig.ident), sig.ident.span())
} else {
sig.ident.clone()
};
let fn_generics = &sig.generics;

// Check the return type
let is_result = is_result(&sig.output);

let (arg_name, arg_ty) = &self.tmp_db_fn_arg;
let fn_out = &sig.output;

let call_func = if is_result {
quote! {(|#arg_name: #arg_ty|#fn_out #block)(#arg_name).unwrap();}
} else {
quote! {(|#arg_name: #arg_ty| #block)(#arg_name);}
};

let tmp_db_builder_args = self.args.get_tmp_db_builder(&fn_name, arg_ty, mvcc);

quote! {
#[test]
#(#attrs)*
#vis fn #fn_name #fn_generics() {
let #arg_name = #tmp_db_builder_args;

#call_func
}

}
}
}

impl ToTokens for DatabaseFunction {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let out = self.tokens_for_db_type(false);
out.to_tokens(tokens);
if self.args.mvcc.is_some() {
let out = self.tokens_for_db_type(true);
out.to_tokens(tokens);
}
}
}

pub fn test_macro_attribute(args: TokenStream, input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as ItemFn);

let args = parse_macro_input!(args as Args);

let tmp_db_arg = match check_fn_inputs(&input) {
Ok(fn_arg) => fn_arg,
Err(err) => return err.into_compile_error().into(),
};

let db_function = DatabaseFunction::new(input, tmp_db_arg, args);

db_function.to_token_stream().into()
}

fn check_fn_inputs(input: &ItemFn) -> syn::Result<(Pat, syn::Type)> {
let msg = "Only 1 function argument can be passed and it must be of type `TempDatabase`";
let args = &input.sig.inputs;
if args.len() != 1 {
return Err(syn::Error::new_spanned(&input.sig, msg));
}
let first = args.first().unwrap();
match first {
syn::FnArg::Receiver(receiver) => Err(syn::Error::new_spanned(receiver, msg)),
syn::FnArg::Typed(pat_type) => {
if let Type::Path(type_path) = pat_type.ty.as_ref() {
// Check if qself is None (not a qualified path like <T as Trait>::Type)
if type_path.qself.is_some() {
return Err(syn::Error::new_spanned(type_path, msg));
}

// Get the last segment of the path
// This works for both:
// - Simple: TempDatabase
// - Qualified: crate::TempDatabase, my_module::TempDatabase
if type_path
.path
.segments
.last()
.is_none_or(|segment| segment.ident != "TempDatabase")
{
return Err(syn::Error::new_spanned(type_path, msg));
}
Ok((*pat_type.pat.clone(), *pat_type.ty.clone()))
} else {
Err(syn::Error::new_spanned(pat_type, msg))
}
}
}
}

fn is_result(return_type: &ReturnType) -> bool {
match return_type {
ReturnType::Default => false, // Returns ()
ReturnType::Type(_, ty) => {
// Check if the type path contains "Result"
if let syn::Type::Path(type_path) = ty.as_ref() {
type_path
.path
.segments
.last()
.map(|seg| seg.ident == "Result")
.unwrap_or(false)
} else {
false
}
}
}
}
1 change: 1 addition & 0 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ tracing = { workspace = true }

[dev-dependencies]
test-log = { version = "0.2.17", features = ["trace"] }
turso_macros.workspace = true

[features]
default = ["test_helper"]
Expand Down
Loading
Loading