|
1 | 1 | use proc_macro::{Span, TokenStream}; |
2 | | -use quote::{quote, TokenStreamExt}; |
3 | | -use syn::{parse_macro_input, DeriveInput, Fields, Ident, Index, Lit}; |
| 2 | +use quote::quote; |
| 3 | +use syn::{parse_macro_input, parse_quote, DeriveInput, Fields, Ident, Index, Lit}; |
4 | 4 |
|
5 | 5 | #[proc_macro_derive(Packet, attributes(packet))] |
6 | 6 | pub fn derive_packet(input: TokenStream) -> TokenStream { |
@@ -182,3 +182,122 @@ pub fn derive_encode(input: TokenStream) -> TokenStream { |
182 | 182 | } |
183 | 183 | .into() |
184 | 184 | } |
| 185 | + |
| 186 | +/// Automatically implements "straight-across" decoding for the given struct, i.e. fields are |
| 187 | +/// deserialized in order as is. Supports #[varint] and #[varlong] attributes on integer types to |
| 188 | +/// deserialize as those formats instead. |
| 189 | +#[proc_macro_derive(Decode, attributes(varint, varlong))] |
| 190 | +pub fn derive_decode(input: TokenStream) -> TokenStream { |
| 191 | + let input = parse_macro_input!(input as DeriveInput); |
| 192 | + |
| 193 | + let syn::Data::Struct(data) = input.data else { |
| 194 | + panic!("Can only derive Decode on a struct"); |
| 195 | + }; |
| 196 | + |
| 197 | + let name = input.ident; |
| 198 | + |
| 199 | + let struct_tokens = match data.fields { |
| 200 | + Fields::Named(fields) => { |
| 201 | + let mut field_tokens = proc_macro2::TokenStream::new(); |
| 202 | + |
| 203 | + for field in fields.named { |
| 204 | + let field_name = field.ident.expect("couldn't get ident for named field"); |
| 205 | + let ty = field.ty; |
| 206 | + |
| 207 | + let wrapped = format!("for field {field_name} in {name}"); |
| 208 | + |
| 209 | + if field |
| 210 | + .attrs |
| 211 | + .iter() |
| 212 | + .any(|attr| attr.meta.path().is_ident("varint")) |
| 213 | + { |
| 214 | + field_tokens.extend(quote! { |
| 215 | + #field_name: VarInt::decode(r) |
| 216 | + .wrap_err(#wrapped)? |
| 217 | + .try_into()?, |
| 218 | + }); |
| 219 | + } else if field |
| 220 | + .attrs |
| 221 | + .iter() |
| 222 | + .any(|attr| attr.meta.path().is_ident("varlong")) |
| 223 | + { |
| 224 | + field_tokens.extend(quote! { |
| 225 | + #field_name: VarLong::decode(r) |
| 226 | + .wrap_err(#wrapped)? |
| 227 | + .try_into()?, |
| 228 | + }); |
| 229 | + } else { |
| 230 | + field_tokens.extend(quote! { |
| 231 | + #field_name: <#ty as Decode>::decode(r) |
| 232 | + .wrap_err(#wrapped)?, |
| 233 | + }); |
| 234 | + } |
| 235 | + } |
| 236 | + quote! { |
| 237 | + Self { |
| 238 | + #field_tokens |
| 239 | + } |
| 240 | + } |
| 241 | + } |
| 242 | + Fields::Unnamed(fields) => { |
| 243 | + let mut field_tokens = proc_macro2::TokenStream::new(); |
| 244 | + for (i, field) in fields.unnamed.into_iter().enumerate() { |
| 245 | + let ty = field.ty; |
| 246 | + |
| 247 | + let wrapped = format!("for field {i} in {name}"); |
| 248 | + |
| 249 | + if field |
| 250 | + .attrs |
| 251 | + .iter() |
| 252 | + .any(|attr| attr.meta.path().is_ident("varint")) |
| 253 | + { |
| 254 | + field_tokens.extend(quote! { |
| 255 | + VarInt::decode(r) |
| 256 | + .wrap_err(#wrapped)? |
| 257 | + .try_into()?, |
| 258 | + }); |
| 259 | + } else if field |
| 260 | + .attrs |
| 261 | + .iter() |
| 262 | + .any(|attr| attr.meta.path().is_ident("varlong")) |
| 263 | + { |
| 264 | + field_tokens.extend(quote! { |
| 265 | + VarLong::decode(r) |
| 266 | + .wrap_err(#wrapped)? |
| 267 | + .try_into()?, |
| 268 | + }); |
| 269 | + } else { |
| 270 | + field_tokens.extend(quote! { |
| 271 | + <#ty as Decode>::decode(r) |
| 272 | + .wrap_err(#wrapped)?, |
| 273 | + }); |
| 274 | + } |
| 275 | + } |
| 276 | + quote! { |
| 277 | + Self(#field_tokens) |
| 278 | + } |
| 279 | + } |
| 280 | + Fields::Unit => quote! { Self }, |
| 281 | + }; |
| 282 | + |
| 283 | + let struct_generics = input.generics; |
| 284 | + let where_clause = struct_generics.where_clause.clone(); |
| 285 | + |
| 286 | + let mut impl_generics = struct_generics.clone(); |
| 287 | + if impl_generics.lifetimes().count() == 0 { |
| 288 | + impl_generics.params.push(parse_quote!('a)); |
| 289 | + } |
| 290 | + |
| 291 | + quote! { |
| 292 | + impl #impl_generics Decode #impl_generics for #name #struct_generics #where_clause { |
| 293 | + fn decode(r: &mut &'a [u8]) -> color_eyre::Result<Self> |
| 294 | + where |
| 295 | + Self: Sized, |
| 296 | + { |
| 297 | + use color_eyre::eyre::WrapErr; |
| 298 | + Ok(#struct_tokens) |
| 299 | + } |
| 300 | + } |
| 301 | + } |
| 302 | + .into() |
| 303 | +} |
0 commit comments