Skip to content

Commit

Permalink
Clean up derive macro codegen.
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciech-graj committed May 18, 2024
1 parent 53a5f01 commit 915730d
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 230 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# v0.4.0
- Delete `EnumExt`
- Bump rust version to 2021
- Bump dependencies, and rust version to 2021
- Make lifetime generics work
- Handle context using generics instead of `Any`
- Improve derive macro hygiene
- Improve derive macro error reporting
# v0.3.4
- Do not trigger https://github.com/rust-lang/rust/issues/120363 with generated code
# v0.3.3
Expand Down
4 changes: 2 additions & 2 deletions bin-proto-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ proc-macro = true

[dependencies]
syn = { version = "1.0.109", features = ["default", "extra-traits", "parsing"] }
quote = "1.0.35"
proc-macro2 = "1.0.79"
quote = "1.0.36"
proc-macro2 = "1.0.82"
217 changes: 129 additions & 88 deletions bin-proto-derive/src/attr.rs
Original file line number Diff line number Diff line change
@@ -1,76 +1,114 @@
use proc_macro2::TokenStream;
use syn::{parse::Parser, punctuated::Punctuated, token::Add, TypeParamBound};
use proc_macro2::{Span, TokenStream};
use syn::{parse::Parser, punctuated::Punctuated, spanned::Spanned, token::Add, Error, Result};

#[derive(Debug, Default)]
pub struct Attrs {
pub discriminant_type: Option<syn::Type>,
pub discriminant: Option<syn::Expr>,
pub ctx: Option<syn::Type>,
pub ctx_bounds: Option<Punctuated<TypeParamBound, Add>>,
pub ctx_bounds: Option<Punctuated<syn::TypeParamBound, Add>>,
pub write_value: Option<syn::Expr>,
pub bits: Option<u32>,
pub flexible_array_member: bool,
pub length: Option<syn::Expr>,
}

impl Attrs {
pub fn validate_enum(&self) {
pub fn validate_enum(&self, span: Span) -> Result<()> {
if self.discriminant_type.is_none() {
panic!("expected discriminant_type attribute for enum")
return Err(Error::new(
span,
"expected discriminant_type attribute for enum",
));
}
if self.discriminant.is_some() {
panic!("unexpected discriminant attribute for enum")
return Err(Error::new(
span,
"unexpected discriminant attribute for enum",
));
}
if self.ctx.is_some() && self.ctx_bounds.is_some() {
panic!("cannot specify ctx and ctx_bounds simultaneously")
return Err(Error::new(
span,
"cannot specify ctx and ctx_bounds simultaneously",
));
}
if self.write_value.is_some() {
panic!("unexpected write_value attribute for enum")
return Err(Error::new(
span,
"unexpected write_value attribute for enum",
));
}
if self.flexible_array_member {
panic!("unexpected flexible_array_member attribute for enum")
return Err(Error::new(
span,
"unexpected flexible_array_member attribute for enum",
));
}
if self.length.is_some() {
panic!("unexpected length attribute for enum")
return Err(Error::new(span, "unexpected length attribute for enum"));
}
Ok(())
}

pub fn validate_variant(&self) {
pub fn validate_variant(&self, span: Span) -> Result<()> {
if self.discriminant_type.is_some() {
panic!("unexpected discriminant_type attribute for variant")
return Err(Error::new(
span,
"unexpected discriminant_type attribute for variant",
));
}
if self.ctx.is_some() {
panic!("unexpected ctx attribute for variant")
return Err(Error::new(span, "unexpected ctx attribute for variant"));
}
if self.ctx_bounds.is_some() {
panic!("unexpected ctx_bounds attribute for variant")
return Err(Error::new(
span,
"unexpected ctx_bounds attribute for variant",
));
}
if self.write_value.is_some() {
panic!("unexpected write_value attribute for variant")
return Err(Error::new(
span,
"unexpected write_value attribute for variant",
));
}
if self.bits.is_some() {
panic!("unexpected bits attribute for variant")
return Err(Error::new(span, "unexpected bits attribute for variant"));
}
if self.flexible_array_member {
panic!("unexpected flexible_array_member attribute for variant")
return Err(Error::new(
span,
"unexpected flexible_array_member attribute for variant",
));
}
if self.length.is_some() {
panic!("unexpected length attribute for variant")
return Err(Error::new(span, "unexpected length attribute for variant"));
}
Ok(())
}

pub fn validate_field(&self) {
pub fn validate_field(&self, span: Span) -> Result<()> {
if self.discriminant_type.is_some() {
panic!("unexpected discriminant_type attribute for field")
return Err(Error::new(
span,
"unexpected discriminant_type attribute for field",
));
}
if self.discriminant.is_some() {
panic!("unexpected discriminant attribute for field")
return Err(Error::new(
span,
"unexpected discriminant attribute for field",
));
}
if self.ctx.is_some() {
panic!("unexpected ctx attribute for variant")
return Err(Error::new(span, "unexpected ctx attribute for variant"));
}
if self.ctx_bounds.is_some() {
panic!("unexpected ctx_bounds attribute for variant")
return Err(Error::new(
span,
"unexpected ctx_bounds attribute for variant",
));
}
if [
self.bits.is_some(),
Expand All @@ -82,24 +120,30 @@ impl Attrs {
.count()
> 1
{
panic!("bits, flexible_array_member, and length are mutually-exclusive attributes")
return Err(Error::new(
span,
"bits, flexible_array_member, and length are mutually-exclusive attributes",
));
}
Ok(())
}

pub fn ctx_tok(&self) -> TokenStream {
pub fn ctx_ty(&self) -> TokenStream {
self.ctx
.clone()
.as_ref()
.map(|ctx| quote!(#ctx))
.unwrap_or(quote!(__Ctx))
}
}

impl From<&[syn::Attribute]> for Attrs {
fn from(value: &[syn::Attribute]) -> Self {
impl TryFrom<&[syn::Attribute]> for Attrs {
type Error = syn::Error;

fn try_from(value: &[syn::Attribute]) -> Result<Self> {
let meta_lists = value.iter().filter_map(|attr| match attr.parse_meta() {
Ok(syn::Meta::List(meta_list)) => {
if meta_list.path.get_ident()
== Some(&syn::Ident::new("protocol", proc_macro2::Span::call_site()))
== Some(&syn::Ident::new("protocol", Span::call_site()))
{
Some(meta_list)
} else {
Expand All @@ -111,87 +155,84 @@ impl From<&[syn::Attribute]> for Attrs {

let mut attribs = Attrs::default();
for meta_list in meta_lists {
for meta in meta_list.nested {
for meta in meta_list.nested.iter() {
match meta {
syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) => {
match name_value.path.get_ident() {
Some(ident) => match &ident.to_string()[..] {
"discriminant_type" => {
attribs.discriminant_type =
Some(meta_name_value_to_parse(name_value))
}
"discriminant" => {
attribs.discriminant =
Some(meta_name_value_to_parse(name_value))
}
"ctx" => attribs.ctx = Some(meta_name_value_to_parse(name_value)),
"ctx_bounds" => {
attribs.ctx_bounds =
Some(meta_name_value_to_punctuated(name_value))
}
"bits" => attribs.bits = Some(meta_name_value_to_u32(name_value)),
"write_value" => {
attribs.write_value = Some(meta_name_value_to_parse(name_value))
}
"length" => {
attribs.length = Some(meta_name_value_to_parse(name_value))
}
ident => panic!("unrecognised #[protocol({})]", ident),
},
None => panic!("failed to parse #[protocol(...)]"),
syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) => match name_value
.path
.get_ident()
{
Some(ident) => match ident.to_string().as_str() {
"discriminant_type" => {
attribs.discriminant_type =
Some(meta_name_value_to_parse(name_value)?)
}
"discriminant" => {
attribs.discriminant = Some(meta_name_value_to_parse(name_value)?)
}
"ctx" => attribs.ctx = Some(meta_name_value_to_parse(name_value)?),
"ctx_bounds" => {
attribs.ctx_bounds =
Some(meta_name_value_to_punctuated(name_value)?)
}
"bits" => attribs.bits = Some(meta_name_value_to_u32(name_value)?),
"write_value" => {
attribs.write_value = Some(meta_name_value_to_parse(name_value)?)
}
"length" => {
attribs.length = Some(meta_name_value_to_parse(name_value)?)
}
_ => {
return Err(Error::new(meta_list.span(), "unrecognised attribute"))
}
},
None => {
return Err(Error::new(meta_list.span(), "failed to parse attribute"))
}
}
},
syn::NestedMeta::Meta(syn::Meta::Path(path)) => match path.get_ident() {
Some(ident) => match ident.to_string().as_str() {
"flexible_array_member" => attribs.flexible_array_member = true,
_ => panic!("unrecognised #[protocol({})]", ident),
_ => {
return Err(Error::new(meta_list.span(), "unrecognised attribute"))
}
},
None => panic!("parsed string was not an identifier"),
None => {
return Err(Error::new(path.get_ident().span(), "expected identifier"))
}
},
_ => {
panic!("unrecognized #[protocol(..)] attribute")
}
_ => return Err(Error::new(meta_list.span(), "unrecognised attribute")),
};
}
}
attribs
Ok(attribs)
}
}

fn meta_name_value_to_parse<T: syn::parse::Parse>(name_value: syn::MetaNameValue) -> T {
fn meta_name_value_to_parse<T: syn::parse::Parse>(name_value: &syn::MetaNameValue) -> Result<T> {
match name_value.lit {
syn::Lit::Str(s) => match syn::parse_str::<T>(s.value().as_str()) {
Ok(f) => f,
Err(_) => {
panic!("Failed to parse '{}'", s.value())
}
},
_ => panic!("#[protocol(... = \"...\")] must be string"),
syn::Lit::Str(ref s) => syn::parse_str::<T>(s.value().as_str())
.map_err(|_| Error::new(name_value.span(), "Failed to parse")),

_ => Err(Error::new(name_value.span(), "Expected string")),
}
}

fn meta_name_value_to_u32(name_value: syn::MetaNameValue) -> u32 {
fn meta_name_value_to_u32(name_value: &syn::MetaNameValue) -> Result<u32> {
match name_value.lit {
syn::Lit::Int(i) => match i.base10_parse() {
Ok(i) => i,
Err(_) => {
panic!("Failed to parse integer from '{}'", i)
}
},
_ => panic!("bitfield size must be an integer"),
syn::Lit::Int(ref i) => i
.base10_parse()
.map_err(|_| Error::new(name_value.span(), "Failed to parse u32")),
_ => Err(Error::new(name_value.span(), "Expected integer")),
}
}

fn meta_name_value_to_punctuated<T: syn::parse::Parse, P: syn::parse::Parse>(
name_value: syn::MetaNameValue,
) -> Punctuated<T, P> {
name_value: &syn::MetaNameValue,
) -> Result<Punctuated<T, P>> {
match name_value.lit {
syn::Lit::Str(s) => match Punctuated::parse_terminated.parse_str(s.value().as_str()) {
Ok(f) => f,
Err(_) => {
panic!("Failed to parse '{}'", s.value())
}
},
_ => panic!("#[protocol(... = \"...\")] must be string"),
syn::Lit::Str(ref s) => Punctuated::parse_terminated
.parse_str(s.value().as_str())
.map_err(|_| Error::new(name_value.span(), "Failed to parse")),
_ => Err(Error::new(name_value.span(), "Expected string")),
}
}
Loading

0 comments on commit 915730d

Please sign in to comment.