Skip to content

Commit f70e206

Browse files
committed
Use enum to configure sync write mode
1 parent b352cdd commit f70e206

File tree

3 files changed

+52
-51
lines changed

3 files changed

+52
-51
lines changed

cached_proc_macro/src/cached.rs

+33-38
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,18 @@ use darling::ast::NestedMeta;
33
use darling::FromMeta;
44
use proc_macro::TokenStream;
55
use quote::quote;
6+
use std::cmp::PartialEq;
67
use syn::spanned::Spanned;
78
use syn::{parse_macro_input, parse_str, Block, Ident, ItemFn, ReturnType, Type};
89

10+
#[derive(Debug, Default, FromMeta, Eq, PartialEq)]
11+
enum SyncWriteMode {
12+
#[default]
13+
Disabled,
14+
Default,
15+
ByKey,
16+
}
17+
918
#[derive(FromMeta)]
1019
struct MacroArgs {
1120
#[darling(default)]
@@ -27,9 +36,7 @@ struct MacroArgs {
2736
#[darling(default)]
2837
option: bool,
2938
#[darling(default)]
30-
sync_writes: bool,
31-
#[darling(default)]
32-
sync_writes_by_key: bool,
39+
sync_writes: SyncWriteMode,
3340
#[darling(default)]
3441
with_cached_flag: bool,
3542
#[darling(default)]
@@ -192,16 +199,8 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
192199
_ => panic!("the result and option attributes are mutually exclusive"),
193200
};
194201

195-
if args.result_fallback && args.sync_writes {
196-
panic!("the result_fallback and sync_writes attributes are mutually exclusive");
197-
}
198-
199-
if args.result_fallback && args.sync_writes_by_key {
200-
panic!("the result_fallback and sync_writes_by_key attributes are mutually exclusive");
201-
}
202-
203-
if args.sync_writes && args.sync_writes_by_key {
204-
panic!("the sync_writes and sync_writes_by_key attributes are mutually exclusive");
202+
if args.result_fallback && args.sync_writes != SyncWriteMode::Disabled {
203+
panic!("result_fallback and sync_writes are mutually exclusive");
205204
}
206205

207206
let set_cache_and_return = quote! {
@@ -216,20 +215,19 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
216215
let function_call;
217216
let ty;
218217
if asyncness.is_some() {
219-
lock = if args.sync_writes_by_key {
220-
quote! {
218+
lock = match args.sync_writes {
219+
SyncWriteMode::ByKey => quote! {
221220
let mut locks = #cache_ident.lock().await;
222221
let lock = locks
223222
.entry(key.clone())
224223
.or_insert_with(|| std::sync::Arc::new(::cached::async_sync::Mutex::new(#cache_create)))
225224
.clone();
226225
drop(locks);
227226
let mut cache = lock.lock().await;
228-
}
229-
} else {
230-
quote! {
227+
},
228+
_ => quote! {
231229
let mut cache = #cache_ident.lock().await;
232-
}
230+
},
233231
};
234232

235233
function_no_cache = quote! {
@@ -240,27 +238,25 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
240238
let result = #no_cache_fn_ident(#(#input_names),*).await;
241239
};
242240

243-
ty = if args.sync_writes_by_key {
244-
quote! {
241+
ty = match args.sync_writes {
242+
SyncWriteMode::ByKey => quote! {
245243
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<::cached::async_sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(std::collections::HashMap::new()));
246-
}
247-
} else {
248-
quote! {
244+
},
245+
_ => quote! {
249246
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(#cache_create));
250-
}
247+
},
251248
};
252249
} else {
253-
lock = if args.sync_writes_by_key {
254-
quote! {
250+
lock = match args.sync_writes {
251+
SyncWriteMode::ByKey => quote! {
255252
let mut locks = #cache_ident.lock().unwrap();
256253
let lock = locks.entry(key.clone()).or_insert_with(|| std::sync::Arc::new(std::sync::Mutex::new(#cache_create))).clone();
257254
drop(locks);
258255
let mut cache = lock.lock().unwrap();
259-
}
260-
} else {
261-
quote! {
256+
},
257+
_ => quote! {
262258
let mut cache = #cache_ident.lock().unwrap();
263-
}
259+
},
264260
};
265261

266262
function_no_cache = quote! {
@@ -271,14 +267,13 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
271267
let result = #no_cache_fn_ident(#(#input_names),*);
272268
};
273269

274-
ty = if args.sync_writes_by_key {
275-
quote! {
270+
ty = match args.sync_writes {
271+
SyncWriteMode::ByKey => quote! {
276272
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<std::sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(std::collections::HashMap::new()));
277-
}
278-
} else {
279-
quote! {
273+
},
274+
_ => quote! {
280275
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(#cache_create));
281-
}
276+
},
282277
}
283278
}
284279

@@ -290,7 +285,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
290285
#set_cache_and_return
291286
};
292287

293-
let do_set_return_block = if args.sync_writes_by_key || args.sync_writes {
288+
let do_set_return_block = if args.sync_writes != SyncWriteMode::Disabled {
294289
quote! {
295290
#lock
296291
if let Some(result) = cache.cache_get(&key) {

cached_proc_macro/src/once.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,21 @@ use quote::quote;
66
use syn::spanned::Spanned;
77
use syn::{parse_macro_input, Ident, ItemFn, ReturnType};
88

9+
#[derive(Debug, Default, FromMeta)]
10+
enum SyncWriteMode {
11+
#[default]
12+
Disabled,
13+
Default,
14+
}
15+
916
#[derive(FromMeta)]
1017
struct OnceMacroArgs {
1118
#[darling(default)]
1219
name: Option<String>,
1320
#[darling(default)]
1421
time: Option<u64>,
1522
#[darling(default)]
16-
sync_writes: bool,
23+
sync_writes: SyncWriteMode,
1724
#[darling(default)]
1825
result: bool,
1926
#[darling(default)]
@@ -220,23 +227,22 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream {
220227
}
221228
};
222229

223-
let do_set_return_block = if args.sync_writes {
224-
quote! {
230+
let do_set_return_block = match args.sync_writes {
231+
SyncWriteMode::Default => quote! {
225232
#r_lock_return_cache_block
226233
#w_lock
227234
if let Some(result) = &*cached {
228235
#return_cache_block
229236
}
230237
#function_call
231238
#set_cache_and_return
232-
}
233-
} else {
234-
quote! {
239+
},
240+
SyncWriteMode::Disabled => quote! {
235241
#r_lock_return_cache_block
236242
#function_call
237243
#w_lock
238244
#set_cache_and_return
239-
}
245+
},
240246
};
241247

242248
let signature_no_muts = get_mut_signature(signature);

tests/cached.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ async fn test_only_cached_option_once_per_second_a() {
848848
/// to return the cached result of the one call instead of all
849849
/// concurrently un-cached tasks executing and writing concurrently.
850850
#[cfg(feature = "async")]
851-
#[once(time = 2, sync_writes = true)]
851+
#[once(time = 2, sync_writes = "default")]
852852
async fn only_cached_once_per_second_sync_writes(s: String) -> Vec<String> {
853853
vec![s]
854854
}
@@ -862,7 +862,7 @@ async fn test_only_cached_once_per_second_sync_writes() {
862862
assert_eq!(a.await.unwrap(), b.await.unwrap());
863863
}
864864

865-
#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")]
865+
#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")]
866866
fn cached_sync_writes(s: String) -> Vec<String> {
867867
vec![s]
868868
}
@@ -881,7 +881,7 @@ fn test_cached_sync_writes() {
881881
}
882882

883883
#[cfg(feature = "async")]
884-
#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")]
884+
#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")]
885885
async fn cached_sync_writes_a(s: String) -> Vec<String> {
886886
vec![s]
887887
}
@@ -898,7 +898,7 @@ async fn test_cached_sync_writes_a() {
898898
assert_eq!(a, c.await.unwrap());
899899
}
900900

901-
#[cached(time = 2, sync_writes_by_key = true, key = "u32", convert = "{ 1 }")]
901+
#[cached(time = 2, sync_writes = "by_key", key = "u32", convert = "{ 1 }")]
902902
fn cached_sync_writes_by_key(s: String) -> Vec<String> {
903903
sleep(Duration::new(1, 0));
904904
vec![s]
@@ -919,7 +919,7 @@ fn test_cached_sync_writes_by_key() {
919919
#[cfg(feature = "async")]
920920
#[cached(
921921
time = 5,
922-
sync_writes_by_key = true,
922+
sync_writes = "by_key",
923923
key = "String",
924924
convert = r#"{ format!("{}", s) }"#
925925
)]
@@ -942,7 +942,7 @@ async fn test_cached_sync_writes_by_key_a() {
942942
}
943943

944944
#[cfg(feature = "async")]
945-
#[once(sync_writes = true)]
945+
#[once(sync_writes = "default")]
946946
async fn once_sync_writes_a(s: &tokio::sync::Mutex<String>) -> String {
947947
let mut guard = s.lock().await;
948948
let results: String = (*guard).clone().to_string();

0 commit comments

Comments
 (0)