diff --git a/derive-encode/Cargo.toml b/derive-encode/Cargo.toml index e24e1933..a9930ccd 100644 --- a/derive-encode/Cargo.toml +++ b/derive-encode/Cargo.toml @@ -20,4 +20,4 @@ syn = "1" prometheus-client = { path = "../", features = ["protobuf"] } [lib] -proc-macro = true \ No newline at end of file +proc-macro = true diff --git a/derive-encode/src/lib.rs b/derive-encode/src/lib.rs index e126889e..62bc3464 100644 --- a/derive-encode/src/lib.rs +++ b/derive-encode/src/lib.rs @@ -9,7 +9,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::DeriveInput; +use syn::{parse::Parse, DeriveInput, Ident, LitStr, Token}; /// Derive `prometheus_client::encoding::EncodeLabelSet`. #[proc_macro_derive(EncodeLabelSet, attributes(prometheus))] @@ -88,22 +88,74 @@ pub fn derive_encode_label_set(input: TokenStream) -> TokenStream { } /// Derive `prometheus_client::encoding::EncodeLabelValue`. -#[proc_macro_derive(EncodeLabelValue)] +/// +/// This macro only applies to `enum`s and will panic if you attempt to use it on structs. +/// +/// At the enum level you can use `#[prometheus(value_case = "lower")]` or `"upper"` to set the +/// default case of the enum variants. +/// +/// ```rust +/// # use prometheus_client::encoding::EncodeLabelValue; +/// #[derive(Clone, Hash, PartialEq, Eq, EncodeLabelValue, Debug)] +/// #[prometheus(value_case = "upper")] +/// enum Method { +/// Get, +/// Put, +/// } +/// ``` +/// +/// Will encode to label values "GET" and "PUT" in prometheus metrics. +/// +/// For variants you can use `#[prometheus(lower)]` or `#[prometheus(upper)]` to set the case for +/// only that variant. +#[proc_macro_derive(EncodeLabelValue, attributes(prometheus))] pub fn derive_encode_label_value(input: TokenStream) -> TokenStream { let ast: DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let config: LabelConfig = ast + .attrs + .iter() + .find_map(|attr| { + if attr.path.is_ident("prometheus") { + match attr.parse_args::() { + Ok(config) => Some(config), + Err(e) => panic!("invalid prometheus attribute: {e}"), + } + } else { + None + } + }) + .unwrap_or_default(); + let body = match ast.clone().data { syn::Data::Struct(_) => { - panic!("Can not derive EncodeLabel for struct.") + panic!("Can not derive EncodeLabelValue for struct.") } syn::Data::Enum(syn::DataEnum { variants, .. }) => { let match_arms: TokenStream2 = variants .into_iter() .map(|v| { let ident = v.ident; + + let attribute = v + .attrs + .iter() + .find(|a| a.path.is_ident("prometheus")) + .map(|a| a.parse_args::().unwrap().to_string()); + let case = match attribute.as_deref() { + Some("lower") => ValueCase::Lower, + Some("upper") => ValueCase::Upper, + Some(other) => { + panic!("Provided attribute '{other}', but only 'lower' and 'upper' are supported") + } + None => config.value_case.clone(), + }; + + let value = case.apply(&ident); + quote! { - #name::#ident => encoder.write_str(stringify!(#ident))?, + #name::#ident => encoder.write_str(stringify!(#value))?, } }) .collect(); @@ -114,7 +166,7 @@ pub fn derive_encode_label_value(input: TokenStream) -> TokenStream { } } } - syn::Data::Union(_) => panic!("Can not derive Encode for union."), + syn::Data::Union(_) => panic!("Can not derive EncodeLabelValue for union."), }; let gen = quote! { @@ -132,6 +184,77 @@ pub fn derive_encode_label_value(input: TokenStream) -> TokenStream { gen.into() } +#[derive(Clone)] +enum ValueCase { + Lower, + Upper, + NoChange, +} + +impl ValueCase { + fn apply(&self, ident: &Ident) -> Ident { + match self { + ValueCase::Lower => Ident::new(&ident.to_string().to_lowercase(), ident.span()), + ValueCase::Upper => Ident::new(&ident.to_string().to_uppercase(), ident.span()), + ValueCase::NoChange => ident.clone(), + } + } +} + +struct LabelConfig { + value_case: ValueCase, +} + +impl Default for LabelConfig { + fn default() -> Self { + Self { + value_case: ValueCase::NoChange, + } + } +} + +impl Parse for LabelConfig { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut config = LabelConfig::default(); + + while input.peek(Ident) { + let ident: Ident = input.parse()?; + + match ident.to_string().as_str() { + "value_case" => { + let _: Token![=] = input.parse()?; + let case: LitStr = input.parse()?; + + match case.value().as_str() { + "lower" => config.value_case = ValueCase::Lower, + "upper" => config.value_case = ValueCase::Upper, + invalid => { + return Err(syn::Error::new( + case.span(), + format!( + "value case may only be \"lower\" or \"upper\", not \"{invalid}\"" + ), + )) + } + } + } + invalid => { + return Err(syn::Error::new( + ident.span(), + format!("invalid prometheus attribute \"{invalid}\""), + )) + } + } + + if input.peek(Token![,]) { + let _: Token![,] = input.parse()?; + } + } + + Ok(config) + } +} + // Copied from https://github.com/djc/askama (MIT and APACHE licensed) and // modified. static KEYWORD_IDENTIFIERS: [(&str, &str); 48] = [ diff --git a/derive-encode/tests/lib.rs b/derive-encode/tests/lib.rs index fba8412d..2c9da863 100644 --- a/derive-encode/tests/lib.rs +++ b/derive-encode/tests/lib.rs @@ -173,3 +173,104 @@ fn flatten() { + "# EOF\n"; assert_eq!(expected, buffer); } + +#[test] +fn case_per_label() { + #[derive(EncodeLabelSet, Hash, Clone, Eq, PartialEq, Debug)] + struct Labels { + lower: EnumLabel, + upper: EnumLabel, + no_change: EnumLabel, + } + + #[derive(EncodeLabelValue, Hash, Clone, Eq, PartialEq, Debug)] + enum EnumLabel { + #[prometheus(lower)] + One, + #[prometheus(upper)] + Two, + Three, + } + + let mut registry = Registry::default(); + let family = Family::::default(); + registry.register("my_counter", "This is my counter", family.clone()); + + // Record a single HTTP GET request. + family + .get_or_create(&Labels { + lower: EnumLabel::One, + upper: EnumLabel::Two, + no_change: EnumLabel::Three, + }) + .inc(); + + // Encode all metrics in the registry in the text format. + let mut buffer = String::new(); + encode(&mut buffer, ®istry).unwrap(); + + let expected = "# HELP my_counter This is my counter.\n".to_owned() + + "# TYPE my_counter counter\n" + + "my_counter_total{lower=\"one\",upper=\"TWO\",no_change=\"Three\"} 1\n" + + "# EOF\n"; + assert_eq!(expected, buffer); +} + +#[test] +fn case_whole_enum() { + #[derive(EncodeLabelSet, Hash, Clone, Eq, PartialEq, Debug)] + struct Labels { + lower: EnumLowerLabel, + upper: EnumUpperLabel, + no_change: EnumNoChangeLabel, + override_case: EnumOverrideLabel, + } + + #[derive(EncodeLabelValue, Hash, Clone, Eq, PartialEq, Debug)] + #[prometheus(value_case = "lower")] + enum EnumLowerLabel { + One, + } + + #[derive(EncodeLabelValue, Hash, Clone, Eq, PartialEq, Debug)] + #[prometheus(value_case = "upper")] + enum EnumUpperLabel { + Two, + } + + #[derive(EncodeLabelValue, Hash, Clone, Eq, PartialEq, Debug)] + enum EnumNoChangeLabel { + Three, + } + + #[derive(EncodeLabelValue, Hash, Clone, Eq, PartialEq, Debug)] + #[prometheus(value_case = "upper")] + enum EnumOverrideLabel { + #[prometheus(lower)] + Four, + } + + let mut registry = Registry::default(); + let family = Family::::default(); + registry.register("my_counter", "This is my counter", family.clone()); + + // Record a single HTTP GET request. + family + .get_or_create(&Labels { + lower: EnumLowerLabel::One, + upper: EnumUpperLabel::Two, + no_change: EnumNoChangeLabel::Three, + override_case: EnumOverrideLabel::Four, + }) + .inc(); + + // Encode all metrics in the registry in the text format. + let mut buffer = String::new(); + encode(&mut buffer, ®istry).unwrap(); + + let expected = "# HELP my_counter This is my counter.\n".to_owned() + + "# TYPE my_counter counter\n" + + "my_counter_total{lower=\"one\",upper=\"TWO\",no_change=\"Three\",override_case=\"four\"} 1\n" + + "# EOF\n"; + assert_eq!(expected, buffer); +}