Skip to content

Commit

Permalink
Implement ctx_generics attribute.
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciech-graj committed Nov 10, 2024
1 parent c11be67 commit 644026f
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# v0.6.0
- Allow multiple attributes in a single `#[protocol(...)]`
- Require unquoted expressions in attributes
- Use nested metas for all lists in attributes
- Add `#[protocol(ctx_generics(...))]`
- Impose `non_exhaustive` on `Error`
# v0.5.0
- Split `Protocol` into `ProtocolRead` and `ProtocolWrite`
Expand Down
26 changes: 22 additions & 4 deletions bin-proto-derive/src/attr.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use proc_macro2::{Span, TokenStream};
use syn::{punctuated::Punctuated, spanned::Spanned, token::Plus, Error, Result};
use syn::{parenthesized, punctuated::Punctuated, spanned::Spanned, token::Comma, Error, Result};

#[derive(Default)]
pub struct Attrs {
pub discriminant_type: Option<syn::Type>,
pub discriminant: Option<syn::Expr>,
pub ctx: Option<syn::Type>,
pub ctx_bounds: Option<Punctuated<syn::TypeParamBound, Plus>>,
pub ctx_generics: Option<Vec<syn::GenericParam>>,
pub ctx_bounds: Option<Vec<syn::TypeParamBound>>,
pub write_value: Option<syn::Expr>,
pub bits: Option<syn::Expr>,
pub flexible_array_member: bool,
Expand Down Expand Up @@ -165,9 +166,26 @@ impl TryFrom<&[syn::Attribute]> for Attrs {
attribs.discriminant = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("ctx") {
attribs.ctx = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("ctx_generics") {
let content;
parenthesized!(content in meta.input);
attribs.ctx_generics = Some(
Punctuated::<syn::GenericParam, Comma>::parse_separated_nonempty(
&content,
)?
.into_iter()
.collect(),
);
} else if meta.path.is_ident("ctx_bounds") {
attribs.ctx_bounds =
Some(Punctuated::parse_separated_nonempty(meta.value()?)?);
let content;
parenthesized!(content in meta.input);
attribs.ctx_bounds = Some(
Punctuated::<syn::TypeParamBound, Comma>::parse_separated_nonempty(
&content,
)?
.into_iter()
.collect(),
);
} else if meta.path.is_ident("bits") {
attribs.bits = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("write_value") {
Expand Down
4 changes: 2 additions & 2 deletions bin-proto-derive/src/codegen/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ pub fn variant_discriminant(plan: &plan::Enum, attribs: &Attrs) -> TokenStream {
})
}

pub fn read_variant_fields(plan: &plan::Enum, attribs: &Attrs) -> TokenStream {
pub fn read_variant_fields(plan: &plan::Enum) -> TokenStream {
let discriminant_match_branches = plan.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let discriminant_literal = &variant.discriminant_value;
let (reader, initializer) = codegen::reads(&variant.fields, attribs);
let (reader, initializer) = codegen::reads(&variant.fields);

quote!(
#discriminant_literal => {
Expand Down
32 changes: 17 additions & 15 deletions bin-proto-derive/src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use crate::attr::{Attrs, Tag};
use proc_macro2::TokenStream;
use syn::{spanned::Spanned, Error};

pub fn reads(fields: &syn::Fields, attrs: &Attrs) -> (TokenStream, TokenStream) {
pub fn reads(fields: &syn::Fields) -> (TokenStream, TokenStream) {
match *fields {
syn::Fields::Named(ref fields) => read_named_fields(fields, attrs),
syn::Fields::Unnamed(ref fields) => (quote!(), read_unnamed_fields(fields, attrs)),
syn::Fields::Named(ref fields) => read_named_fields(fields),
syn::Fields::Unnamed(ref fields) => (quote!(), read_unnamed_fields(fields)),
syn::Fields::Unit => (quote!(), quote!()),
}
}
Expand All @@ -21,15 +21,15 @@ pub fn writes(fields: &syn::Fields, self_prefix: bool) -> TokenStream {
}
}

fn read_named_fields(fields_named: &syn::FieldsNamed, attrs: &Attrs) -> (TokenStream, TokenStream) {
fn read_named_fields(fields_named: &syn::FieldsNamed) -> (TokenStream, TokenStream) {
let fields: Vec<_> = fields_named
.named
.iter()
.map(|field| {
let field_name = &field.ident;
let field_ty = &field.ty;

let read = read(field, attrs);
let read = read(field);

quote!(
let #field_name : #field_ty = #read?;
Expand All @@ -53,7 +53,7 @@ fn read_named_fields(fields_named: &syn::FieldsNamed, attrs: &Attrs) -> (TokenSt
)
}

fn read(field: &syn::Field, parent_attribs: &Attrs) -> TokenStream {
fn read(field: &syn::Field) -> TokenStream {
let attribs = match Attrs::try_from(field.attrs.as_slice()) {
Ok(attribs) => attribs,
Err(e) => return e.to_compile_error(),
Expand All @@ -62,10 +62,8 @@ fn read(field: &syn::Field, parent_attribs: &Attrs) -> TokenStream {
return e.to_compile_error();
};

let ctx_ty = parent_attribs.ctx_ty();

if let Some(field_width) = attribs.bits {
quote!(::bin_proto::BitFieldRead::<#ctx_ty>::read(__io_reader, __byte_order, __ctx, #field_width))
quote!(::bin_proto::BitFieldRead::read(__io_reader, __byte_order, __ctx, #field_width))
} else if attribs.flexible_array_member {
quote!(::bin_proto::FlexibleArrayMemberRead::read(
__io_reader,
Expand All @@ -75,20 +73,24 @@ fn read(field: &syn::Field, parent_attribs: &Attrs) -> TokenStream {
} else if let Some(tag) = attribs.tag {
match tag {
Tag::External(tag) => {
quote!(::bin_proto::TaggedRead::<_, #ctx_ty>::read(__io_reader, __byte_order, __ctx, #tag))
quote!(::bin_proto::TaggedRead::read(__io_reader, __byte_order, __ctx, #tag))
}
Tag::Prepend {
typ,
write_value: _,
} => {
quote!({
let __tag = ::bin_proto::ProtocolRead::<#ctx_ty>::read(__io_reader, __byte_order, __ctx)?;
::bin_proto::TaggedRead::<#typ, #ctx_ty>::read(__io_reader, __byte_order, __ctx, __tag)
let __tag = ::bin_proto::ProtocolRead::read(__io_reader, __byte_order, __ctx)?;
::bin_proto::TaggedRead::<#typ, _>::read(__io_reader, __byte_order, __ctx, __tag)
})
}
}
} else {
quote!(::bin_proto::ProtocolRead::<#ctx_ty>::read(__io_reader, __byte_order, __ctx))
quote!(::bin_proto::ProtocolRead::read(
__io_reader,
__byte_order,
__ctx
))
}
}

Expand Down Expand Up @@ -173,13 +175,13 @@ fn write_named_fields(fields_named: &syn::FieldsNamed, self_prefix: bool) -> Tok
quote!( #( #field_writers );* )
}

fn read_unnamed_fields(fields_unnamed: &syn::FieldsUnnamed, attrs: &Attrs) -> TokenStream {
fn read_unnamed_fields(fields_unnamed: &syn::FieldsUnnamed) -> TokenStream {
let field_initializers: Vec<_> = fields_unnamed
.unnamed
.iter()
.map(|field| {
let field_ty = &field.ty;
let read = read(field, attrs);
let read = read(field);

quote!(
{
Expand Down
8 changes: 7 additions & 1 deletion bin-proto-derive/src/codegen/trait_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ pub fn impl_trait_for(
| TraitImplType::UntaggedWrite
) {
trait_generics.push(if let Some(ctx) = attribs.ctx {
if let Some(ctx_generics) = attribs.ctx_generics {
generics.params.extend(ctx_generics);
}
quote!(#ctx)
} else {
let ident = syn::Ident::new("__Ctx", Span::call_site());
Expand All @@ -69,7 +72,10 @@ pub fn impl_trait_for(
attrs: Vec::new(),
ident: ident.clone(),
colon_token: None,
bounds: attribs.ctx_bounds.unwrap_or(Punctuated::new()),
bounds: attribs
.ctx_bounds
.map(|ctx_bounds| ctx_bounds.into_iter().collect())
.unwrap_or_default(),
eq_token: None,
default: None,
}));
Expand Down
4 changes: 2 additions & 2 deletions bin-proto-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn impl_for_struct(

let (impl_body, trait_type) = match protocol_type {
Operation::Read => {
let (reads, initializers) = codegen::reads(&strukt.fields, &attribs);
let (reads, initializers) = codegen::reads(&strukt.fields);
(
quote!(
#[allow(unused_variables)]
Expand Down Expand Up @@ -109,7 +109,7 @@ fn impl_for_enum(

match protocol_type {
Operation::Read => {
let read_variant = codegen::enums::read_variant_fields(&plan, &attribs);
let read_variant = codegen::enums::read_variant_fields(&plan);
let impl_body = quote!(
#[allow(unused_variables)]
fn read(__io_reader: &mut dyn ::bin_proto::BitRead,
Expand Down
2 changes: 1 addition & 1 deletion bin-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ pub use self::tagged::{TaggedRead, UntaggedWrite};
/// }
///
/// #[derive(ProtocolRead, ProtocolWrite)]
/// #[protocol(ctx_bounds = CtxTrait)]
/// #[protocol(ctx_bounds(CtxTrait))]
/// pub struct WithCtx(NeedsCtx);
/// ```
#[cfg(feature = "derive")]
Expand Down
32 changes: 31 additions & 1 deletion bin-proto/tests/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
use bin_proto::{ByteOrder, ProtocolRead, ProtocolWrite};

trait Boolean {
fn set(&mut self);
}

impl Boolean for bool {
fn set(&mut self) {
*self = true
}
}

trait CtxTrait {
fn call(&mut self);
}
Expand All @@ -13,6 +23,18 @@ impl CtxTrait for CtxStruct {
}
}

#[derive(Debug)]
struct CtxStructWithGenerics<'a, T>(&'a mut T);

impl<'a, T> CtxTrait for CtxStructWithGenerics<'a, T>
where
T: Boolean,
{
fn call(&mut self) {
self.0.set()
}
}

#[derive(Debug)]
struct CtxCheck;

Expand Down Expand Up @@ -44,7 +66,15 @@ impl<Ctx: CtxTrait> ProtocolWrite<Ctx> for CtxCheck {
struct CtxCheckStructWrapper(CtxCheck);

#[derive(Debug, ProtocolRead, ProtocolWrite)]
#[protocol(ctx_bounds = CtxTrait)]
#[protocol(ctx = CtxStructWithGenerics<'a, bool>, ctx_generics('a))]
struct CtxCheckStructWrapperWithGenericsConcreteBool(CtxCheck);

#[derive(Debug, ProtocolRead, ProtocolWrite)]
#[protocol(ctx = CtxStructWithGenerics<'a, T>, ctx_generics('a, T: Boolean))]
struct CtxCheckStructWrapperWithGenerics(CtxCheck);

#[derive(Debug, ProtocolRead, ProtocolWrite)]
#[protocol(ctx_bounds(CtxTrait))]
struct CtxCheckTraitWrapper(CtxCheck);

#[test]
Expand Down

0 comments on commit 644026f

Please sign in to comment.